# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union

import collections
import gc
import itertools
import os
import re
import shutil
import tempfile
from dataclasses import dataclass

import torch
import torch.nn as nn

from transformers import PreTrainedModel
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.generation.streamers import BaseStreamer
from transformers.generation.utils import GenerateNonBeamOutput
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import (
    _load_state_dict_into_meta_model,
    expand_device_map,
    get_disk_only_shard_files,
    is_fsdp_enabled,
    is_local_dist_rank_0,
    load_state_dict,
    set_initialized_submodules,
)
from transformers.pytorch_utils import id_tensor_storage
from transformers.utils import ModelOutput, is_accelerate_available, logging

if is_accelerate_available():
    from accelerate.utils import (
        find_tied_parameters,
        load_offloaded_weights,
        save_offload_index,
        set_module_tensor_to_device,
    )

logger = logging.get_logger(__name__)

PARAM_RENAME_WARNING = "A parameter name that contains `{}` will be renamed internally to `{}`. Please use a different name to suppress this warning."
@dataclass
class GenerateDecoderOnlyOutput(ModelOutput):
    """
    Outputs of decoder-only generation models, when using non-beam methods.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_logits=True` is passed or when `config.output_logits=True`):
            Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            NOTE: some models have a different `past_key_values` format, confirm with the model's documentation.
            Usually a tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value
            tensor). The first tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and, optionally if
            `config.is_encoder_decoder=True`, 2 additional tensors of shape `(batch_size, num_heads,
            encoder_sequence_length, embed_size_per_head)`.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    logits: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
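

# Illustrative sketch (not part of the original module): how a caller might consume a
# GenerateDecoderOnlyOutput when `return_dict_in_generate=True`. The names `model`,
# `tokenizer`, and `input_ids` below are hypothetical placeholders.
#
#   output = model.generate(
#       input_ids,
#       max_new_tokens=16,
#       return_dict_in_generate=True,
#       output_scores=True,
#   )
#   texts = tokenizer.batch_decode(output.sequences, skip_special_tokens=True)
#   # one processed score tensor per generated token, each of shape (batch_size, vocab_size)
#   num_steps = len(output.scores)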
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
    # Convert old format to new format if needed from a PyTorch state_dict
    old_keys = []
    new_keys = []
    for key in state_dict.keys():
        new_key = None
        if "gamma" in key and ("vision_tower.vision_tower" not in key and "dav2_model" not in key):
            logger.warning(PARAM_RENAME_WARNING.format("gamma", "weight"))
            new_key = key.replace("gamma", "weight")
        if "beta" in key and "vision_tower.vision_tower" not in key:
            logger.warning(PARAM_RENAME_WARNING.format("beta", "bias"))
            new_key = key.replace("beta", "bias")
        if new_key:
            old_keys.append(key)
            new_keys.append(new_key)
    for old_key, new_key in zip(old_keys, new_keys):
        state_dict[new_key] = state_dict.pop(old_key)

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
    # so we need to apply the function recursively.
    def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
        # Parameters of module and children will start with prefix. We can exit early if there are none in this
        # state_dict
        if len([key for key in state_dict if key.startswith(prefix)]) > 0:
            if is_deepspeed_zero3_enabled():
                import deepspeed

                # In sharded models, each shard has only part of the full state_dict, so only gather
                # parameters that are in the current state_dict.
                named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
                params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
                if len(params_to_gather) > 0:
                    # because zero3 puts placeholders in model params, this context
                    # manager gathers (unpartitions) the params of the current layer, then loads from
                    # the state dict and then re-partitions them again
                    with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
                        if torch.distributed.get_rank() == 0:
                            module._load_from_state_dict(*args)
            else:
                module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is not None:
                load(child, state_dict, prefix + name + ".", assign_to_params_buffers)

    load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers)
    # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
    # it's safe to delete it.
    del state_dict

    return error_msgs
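

# Illustrative sketch (not in the original code): the rename pass above maps legacy TensorFlow-style
# parameter names onto their PyTorch equivalents, while leaving vision-tower / depth-model parameters
# (whose names legitimately contain "gamma"/"beta") untouched. Hypothetical example:
#
#   {"encoder.layer.0.LayerNorm.gamma": t1, "encoder.layer.0.LayerNorm.beta": t2}
#   becomes
#   {"encoder.layer.0.LayerNorm.weight": t1, "encoder.layer.0.LayerNorm.bias": t2}
#   while a key like "vision_tower.vision_tower...gamma" is kept as-is.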
def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
    """
    Checks if `model_to_load` supports param buffer assignment (such
    as when loading in empty weights) by first checking
    if the model explicitly disables it, then by ensuring that the state dict keys
    are a subset of the model's parameters.

    Note: We fully disable this if we are using `deepspeed`
    """
    if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
        return False

    if is_deepspeed_zero3_enabled():
        return False

    # Some models explicitly do not support param buffer assignment
    if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
        logger.debug(
            f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
        )
        return False

    # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
    first_key = list(model_to_load.state_dict().keys())[0]
    if start_prefix + first_key in state_dict:
        return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype

    # For cases when the `state_dict` doesn't contain real weights to the model (`test_model_weights_reload_no_missing_tied_weights`)
    return False
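

# Illustrative sketch (not in the original code): the dtype comparison above decides whether checkpoint
# tensors can be assigned directly into the module's parameter/buffer storage instead of being copied.
# For example (hypothetical key), a float16 checkpoint loaded into a float32-initialized model:
#
#   state_dict["model.embed_tokens.weight"].dtype                  # torch.float16
#   model_to_load.state_dict()["model.embed_tokens.weight"].dtype  # torch.float32
#   -> check_support_param_buffer_assignment(...) returns False, so the slower copy path is used.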
class BaseCausalLM(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

    def _sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        synced_gpus: bool,
        streamer: Optional["BaseStreamer"],
        logits_warper: Optional[LogitsProcessorList] = None,
        **model_kwargs,
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
        can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            logits_processor (`LogitsProcessorList`):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            generation_config ([`~generation.GenerationConfig`]):
                The generation configuration to be used as parametrization of the decoding method.
            synced_gpus (`bool`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3).
            streamer (`BaseStreamer`, *optional*):
                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
            logits_warper (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
                to warp the prediction score distribution of the language modeling head applied before multinomial
                sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
                `generation_config`).
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
            A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.
        """
        # init values
        pad_token_id = generation_config.pad_token_id
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate
        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
        do_sample = generation_config.do_sample
        if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
            raise ValueError(
                "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
                f"{logits_warper})."
            )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # keep track of which sequences are already finished
        batch_size = input_ids.shape[0]
        this_peer_finished = False
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            if do_sample:
                next_token_scores = logits_warper(input_ids, next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_logits:
                    raw_logits += (next_token_logits,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)
                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            probs = nn.functional.softmax(next_token_scores, dim=-1)

            # token selection
            if do_sample:
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(next_token_scores, dim=-1)

            # finished sentences should have their next token be a padding token
            if has_eos_stopping_criteria:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
            )

            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
            this_peer_finished = unfinished_sequences.max() == 0

        if streamer is not None:
            streamer.end()

        if return_dict_in_generate:
            return GenerateDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return input_ids
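
    # Illustrative sketch (not in the original code): the per-step token selection above reduces to
    #   do_sample=True  -> next_tokens = torch.multinomial(softmax(scores), 1)   (stochastic)
    #   do_sample=False -> next_tokens = torch.argmax(scores, dim=-1)            (greedy)
    # and rows that are already finished get overwritten with the pad token. With a hypothetical batch:
    #   unfinished_sequences = [1, 0], next_tokens = [42, 17], pad_token_id = 0
    #   -> next_tokens * unfinished + pad * (1 - unfinished) = [42, 0]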
    @classmethod
    def _load_pretrained_model(
        cls,
        model,
        state_dict,
        loaded_keys,
        resolved_archive_file,
        pretrained_model_name_or_path,
        ignore_mismatched_sizes=False,
        sharded_metadata=None,
        _fast_init=True,
        low_cpu_mem_usage=False,
        device_map=None,
        offload_folder=None,
        offload_state_dict=None,
        dtype=None,
        hf_quantizer=None,
        keep_in_fp32_modules=None,
        gguf_path=None,
    ):
        is_safetensors = False
        is_quantized = hf_quantizer is not None
        state_dict_folder = None
        state_dict_index = None

        if device_map is not None and "disk" in device_map.values():
            archive_file = (
                resolved_archive_file[0] if isinstance(resolved_archive_file, (list, tuple)) else resolved_archive_file
            )
            is_safetensors = archive_file.endswith(".safetensors")
            if offload_folder is None and not is_safetensors:
                raise ValueError(
                    "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
                    " for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
                    " offers the weights in this format."
                )
            if offload_folder is not None:
                os.makedirs(offload_folder, exist_ok=True)
            if offload_state_dict is None:
                offload_state_dict = True

        is_sharded_safetensors = is_safetensors and sharded_metadata is not None

        for key, param in model.state_dict().items():
            if param.device == torch.device("meta"):
                try:
                    set_module_tensor_to_device(
                        model, key, "cuda", torch.empty(*param.size(), dtype=dtype)
                    )
                except Exception:
                    # Some keys cannot be materialized this way; skip them and let the regular load handle them.
                    pass
        # tie the model weights before retrieving the state_dict
        model.tie_weights()

        # Retrieve missing & unexpected_keys
        model_state_dict = model.state_dict()
        expected_keys = list(model_state_dict.keys())
        prefix = model.base_model_prefix

        def _fix_key(key):
            if "beta" in key and "vision_tower.vision_tower" not in key:
                return key.replace("beta", "bias")
            if "gamma" in key and ("vision_tower.vision_tower" not in key and "dav2_model" not in key):
                return key.replace("gamma", "weight")
            return key

        original_loaded_keys = loaded_keys
        loaded_keys = [_fix_key(key) for key in loaded_keys]

        if len(prefix) > 0:
            has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
            expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
        else:
            has_prefix_module = False
            expects_prefix_module = False

        # key re-naming operations are never done on the keys
        # that are loaded, but always on the keys of the newly initialized model
        remove_prefix_from_model = not has_prefix_module and expects_prefix_module
        add_prefix_to_model = has_prefix_module and not expects_prefix_module

        if remove_prefix_from_model:
            _prefix = f"{prefix}."
            expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(_prefix)]
            expected_keys = [s[len(_prefix) :] if s.startswith(_prefix) else s for s in expected_keys]
        elif add_prefix_to_model:
            expected_keys = [".".join([prefix, s]) for s in expected_keys]

        missing_keys = sorted(set(expected_keys) - set(loaded_keys))
        unexpected_keys = set(loaded_keys) - set(expected_keys)
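
        # Illustrative sketch (not in the original code): the prefix logic above reconciles checkpoints
        # saved from a bare base model with models that wrap it under `base_model_prefix` (and vice versa).
        # Hypothetical example with prefix == "model":
        #   checkpoint keys: "layers.0.self_attn.q_proj.weight", ...
        #   expected keys:   "model.layers.0.self_attn.q_proj.weight", ...
        #   -> remove_prefix_from_model is True, so expected keys are compared without the "model." prefix.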
        # Remove nonpersistent buffers from unexpected keys: they are not in the state dict but will be in the model
        # buffers
        model_buffers = {n for n, _ in model.named_buffers()}
        if remove_prefix_from_model:
            model_buffers = {key[len(_prefix) :] if key.startswith(_prefix) else key for key in model_buffers}
        elif add_prefix_to_model:
            model_buffers = {".".join([prefix, key]) for key in model_buffers}
        unexpected_keys = sorted(unexpected_keys - model_buffers)

        model.tie_weights()
        if device_map is None and not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
            ptrs = collections.defaultdict(list)
            for name, tensor in model.state_dict().items():
                id_tensor = id_tensor_storage(tensor)
                ptrs[id_tensor].append(name)

            # These are all the pointers of shared tensors.
            tied_params = [names for _, names in ptrs.items() if len(names) > 1]
        else:
            # id function doesn't work for meta tensor so we need this function
            tied_params = find_tied_parameters(model)

        for group in tied_params:
            if remove_prefix_from_model:
                group = [key[len(_prefix) :] if key.startswith(_prefix) else key for key in group]
            elif add_prefix_to_model:
                group = [".".join([prefix, key]) for key in group]
            missing_in_group = [k for k in missing_keys if k in group]
            if len(missing_in_group) > 0 and len(missing_in_group) < len(group):
                missing_keys = [k for k in missing_keys if k not in missing_in_group]
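
        # Illustrative sketch (not in the original code): with tied embeddings, `lm_head.weight` and
        # `model.embed_tokens.weight` share storage and form one tied group. If the checkpoint only stores
        # the embedding weight, the loop above drops `lm_head.weight` from `missing_keys`, because the tie
        # will fill it rather than leave it uninitialized.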
        # Some models may have keys that are not in the state by design, removing them before needlessly warning
        # the user.
        if cls._keys_to_ignore_on_load_missing is not None:
            for pat in cls._keys_to_ignore_on_load_missing:
                missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

        if cls._keys_to_ignore_on_load_unexpected is not None:
            for pat in cls._keys_to_ignore_on_load_unexpected:
                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

        if hf_quantizer is not None:
            missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)
        # retrieve weights on meta device and put them back on CPU.
        # This is not ideal in terms of memory, but if we don't do that now, we can't initialize them in the next step
        if low_cpu_mem_usage:
            for key in missing_keys:
                if key in list(model_state_dict.keys()):
                    key = key
                elif f"{prefix}.{key}" in list(model_state_dict.keys()):
                    key = f"{prefix}.{key}"
                elif key.startswith(prefix) and ".".join(key.split(".")[1:]) in list(model_state_dict.keys()):
                    key = ".".join(key.split(".")[1:])
                param = model_state_dict[key]

                # upcast in fp32 if any
                target_dtype = dtype
                if (
                    keep_in_fp32_modules is not None
                    and dtype == torch.float16
                    and any(
                        module_to_keep_in_fp32 in key.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
                    )
                ):
                    target_dtype = torch.float32

                if param.device == torch.device("meta"):
                    value = torch.empty(*param.size(), dtype=target_dtype)
                    if (
                        not is_quantized
                        or getattr(hf_quantizer, "requires_parameters_quantization", False)
                        or not hf_quantizer.check_quantized_param(
                            model, param_value=value, param_name=key, state_dict={}
                        )
                    ):
                        set_module_tensor_to_device(model, key, "cpu", value)
                    else:
                        hf_quantizer.create_quantized_param(model, value, key, "cpu", state_dict, unexpected_keys)

        # retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights.
        if _fast_init:
            if not ignore_mismatched_sizes:
                if remove_prefix_from_model:
                    _loaded_keys = [f"{prefix}.{k}" for k in loaded_keys]
                elif add_prefix_to_model:
                    _loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys]
                else:
                    _loaded_keys = loaded_keys
                not_initialized_submodules = set_initialized_submodules(model, _loaded_keys)
                # If we're about to tie the output embeds to the input embeds we don't need to init them
                if hasattr(model.config, "tie_word_embeddings") and model.config.tie_word_embeddings:
                    output_embeddings = model.get_output_embeddings()
                    if output_embeddings is not None:
                        # Still need to initialize if there is a bias term since biases are not tied.
                        if not hasattr(output_embeddings, "bias") or output_embeddings.bias is None:
                            output_embeddings._is_hf_initialized = True
            else:
                not_initialized_submodules = dict(model.named_modules())
            # This will only initialize submodules that are not marked as initialized by the line above.
            if is_deepspeed_zero3_enabled() and not is_quantized:
                import deepspeed

                not_initialized_parameters = list(
                    set(
                        itertools.chain.from_iterable(
                            submodule.parameters(recurse=False) for submodule in not_initialized_submodules.values()
                        )
                    )
                )
                with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0):
                    model.apply(model._initialize_weights)
            else:
                model.apply(model._initialize_weights)

        # Set some modules to fp32 if any
        if keep_in_fp32_modules is not None:
            for name, param in model.named_parameters():
                if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
                    # param = param.to(torch.float32) does not work here as only in the local scope.
                    param.data = param.data.to(torch.float32)
        # Make sure we are able to load base models as well as derived models (with heads)
        start_prefix = ""
        model_to_load = model
        if len(cls.base_model_prefix) > 0 and not hasattr(model, cls.base_model_prefix) and has_prefix_module:
            start_prefix = cls.base_model_prefix + "."
        if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module:
            model_to_load = getattr(model, cls.base_model_prefix)
            base_model_expected_keys = list(model_to_load.state_dict().keys())
            if any(key in expected_keys_not_prefixed and key not in base_model_expected_keys for key in loaded_keys):
                raise ValueError(
                    "The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
                    "properly saved?"
                )
            if device_map is not None:
                device_map = {k.replace(f"{cls.base_model_prefix}.", ""): v for k, v in device_map.items()}

        def _find_mismatched_keys(
            state_dict,
            model_state_dict,
            loaded_keys,
            add_prefix_to_model,
            remove_prefix_from_model,
            ignore_mismatched_sizes,
        ):
            mismatched_keys = []
            if ignore_mismatched_sizes:
                for checkpoint_key in loaded_keys:
                    # If the checkpoint is sharded, we may not have the key here.
                    if checkpoint_key not in state_dict:
                        continue
                    model_key = checkpoint_key
                    if remove_prefix_from_model:
                        # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
                        model_key = f"{prefix}.{checkpoint_key}"
                    elif add_prefix_to_model:
                        # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
                        model_key = ".".join(checkpoint_key.split(".")[1:])

                    if (
                        model_key in model_state_dict
                        and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
                    ):
                        if (
                            state_dict[checkpoint_key].shape[-1] == 1
                            and state_dict[checkpoint_key].numel() * 2 == model_state_dict[model_key].numel()
                        ):
                            # This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container,
                            # causing size differences. Without matching on module type or parameter type, this seems
                            # like a practical way to detect valid 4-bit weights.
                            pass
                        else:
                            mismatched_keys.append(
                                (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
                            )
                            del state_dict[checkpoint_key]
            return mismatched_keys
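
        # Illustrative sketch (not in the original code): with `ignore_mismatched_sizes=True`, a head whose
        # checkpoint shape is (2, 768) but whose freshly initialized shape is (5, 768) would be reported as
        # ("classifier.weight", torch.Size([2, 768]), torch.Size([5, 768])) and dropped from the state dict
        # (hypothetical key). The `numel() * 2` check leaves packed 4-bit weights alone, since two 4-bit
        # values share one 8-bit element and the apparent size difference is expected.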
        if resolved_archive_file is not None:
            folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1])
        else:
            folder = None

        if device_map is not None and is_safetensors:
            param_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
            str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
            if sharded_metadata is None:
                archive_file = (
                    resolved_archive_file[0]
                    if isinstance(resolved_archive_file, (list, tuple))
                    else resolved_archive_file
                )
                weight_map = {p: archive_file for p in original_loaded_keys}
            else:
                weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}

            offload_index = {
                p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
                for p, f in weight_map.items()
                if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk"
            }
        else:
            offload_index = None
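
        # Illustrative sketch (not in the original code): a single `offload_index` entry built above might
        # look like (hypothetical key and file name)
        #   "model.layers.31.mlp.down_proj.weight": {
        #       "safetensors_file": ".../model-00004-of-00004.safetensors",
        #       "weight_name": "model.layers.31.mlp.down_proj.weight",
        #       "dtype": "float16",
        #   }
        # i.e. enough metadata for accelerate to read the disk-offloaded weight lazily later on.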
        if state_dict is not None:
            # Whole checkpoint
            mismatched_keys = _find_mismatched_keys(
                state_dict,
                model_state_dict,
                original_loaded_keys,
                add_prefix_to_model,
                remove_prefix_from_model,
                ignore_mismatched_sizes,
            )

            # For GGUF models `state_dict` is never set to None as the state dict is always small
            if gguf_path:
                error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
                    model_to_load,
                    state_dict,
                    loaded_keys,
                    start_prefix,
                    expected_keys,
                    device_map=device_map,
                    offload_folder=offload_folder,
                    offload_index=offload_index,
                    state_dict_folder=state_dict_folder,
                    state_dict_index=state_dict_index,
                    dtype=dtype,
                    hf_quantizer=hf_quantizer,
                    is_safetensors=is_safetensors,
                    keep_in_fp32_modules=keep_in_fp32_modules,
                    unexpected_keys=unexpected_keys,
                )
            else:
                # Sharded checkpoint or whole but low_cpu_mem_usage==True
                assign_to_params_buffers = check_support_param_buffer_assignment(
                    model_to_load, state_dict, start_prefix
                )
                error_msgs = _load_state_dict_into_model(
                    model_to_load, state_dict, start_prefix, assign_to_params_buffers
                )
        else:
            # This should always be a list but, just to be sure.
            if not isinstance(resolved_archive_file, list):
                resolved_archive_file = [resolved_archive_file]

            error_msgs = []
            mismatched_keys = []
            if not is_safetensors:
                offload_index = {} if device_map is not None and "disk" in device_map.values() else None
            if offload_state_dict:
                state_dict_folder = tempfile.mkdtemp()
                state_dict_index = {}
            else:
                state_dict_folder = None
                state_dict_index = None

            if is_sharded_safetensors:
                disk_only_shard_files = get_disk_only_shard_files(
                    device_map, sharded_metadata=sharded_metadata, start_prefix=start_prefix
                )
                disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files]
            else:
                disk_only_shard_files = []
            if len(resolved_archive_file) > 1:
                resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
            assign_to_params_buffers = None
            for shard_file in resolved_archive_file:
                # Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
                if shard_file in disk_only_shard_files:
                    continue
                state_dict = load_state_dict(shard_file, is_quantized=is_quantized)

                # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
                # matching the weights in the model.
                mismatched_keys += _find_mismatched_keys(
                    state_dict,
                    model_state_dict,
                    original_loaded_keys,
                    add_prefix_to_model,
                    remove_prefix_from_model,
                    ignore_mismatched_sizes,
                )
                if low_cpu_mem_usage:
                    if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
                        for key, param in model_to_load.state_dict().items():
                            if param.device == torch.device("meta"):
                                set_module_tensor_to_device(
                                    model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype)
                                )
                    else:
                        new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
                            model_to_load,
                            state_dict,
                            loaded_keys,
                            start_prefix,
                            expected_keys,
                            device_map=device_map,
                            offload_folder=offload_folder,
                            offload_index=offload_index,
                            state_dict_folder=state_dict_folder,
                            state_dict_index=state_dict_index,
                            dtype=dtype,
                            hf_quantizer=hf_quantizer,
                            is_safetensors=is_safetensors,
                            keep_in_fp32_modules=keep_in_fp32_modules,
                            unexpected_keys=unexpected_keys,
                        )
                        error_msgs += new_error_msgs
                else:
                    # Sharded checkpoint or whole but low_cpu_mem_usage==True
                    if assign_to_params_buffers is None:
                        assign_to_params_buffers = check_support_param_buffer_assignment(
                            model_to_load, state_dict, start_prefix
                        )
                    error_msgs += _load_state_dict_into_model(
                        model_to_load, state_dict, start_prefix, assign_to_params_buffers
                    )

                # force memory release
                del state_dict
                gc.collect()
        if offload_index is not None and len(offload_index) > 0:
            if model != model_to_load:
                # We need to add the prefix of the base model
                prefix = cls.base_model_prefix
                if not is_safetensors:
                    for weight_name in offload_index:
                        shutil.move(
                            os.path.join(offload_folder, f"{weight_name}.dat"),
                            os.path.join(offload_folder, f"{prefix}.{weight_name}.dat"),
                        )
                offload_index = {f"{prefix}.{key}": value for key, value in offload_index.items()}
            if not is_safetensors:
                save_offload_index(offload_index, offload_folder)
                offload_index = None

        if offload_state_dict:
            # Load back temporarily offloaded state dict
            load_offloaded_weights(model_to_load, state_dict_index, state_dict_folder)
            shutil.rmtree(state_dict_folder)

        if len(error_msgs) > 0:
            error_msg = "\n\t".join(error_msgs)
            if "size mismatch" in error_msg:
                error_msg += (
                    "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
                )
            raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

        if len(unexpected_keys) > 0:
            archs = [] if model.config.architectures is None else model.config.architectures
            warner = logger.warning if model.__class__.__name__ in archs else logger.info
            warner(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
                f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
                " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
                " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
                f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
                " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
        if len(missing_keys) > 0:
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        elif len(mismatched_keys) == 0:
            logger.info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
                f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
                " training."
            )
        if len(mismatched_keys) > 0:
            mismatched_warning = "\n".join(
                [
                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                    for key, shape1, shape2 in mismatched_keys
                ]
            )
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
                f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
                " to use it for predictions and inference."
            )

        return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs
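

# Illustrative sketch (not in the original code): this override is normally reached indirectly through
# `from_pretrained`, e.g. (hypothetical class and checkpoint name)
#
#   model = SomeCausalLM.from_pretrained(
#       "org/checkpoint-name",
#       torch_dtype=torch.float16,
#       device_map="auto",
#       offload_folder="offload",   # needed if any weights land on "disk" without safetensors
#       low_cpu_mem_usage=True,
#   )
#
# which supplies the `state_dict` / `resolved_archive_file`, the `device_map`, and the offload settings
# consumed by `_load_pretrained_model` above.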