diff --git a/.gitattributes b/.gitattributes index c7d9f3332a950355d5a77d85000f05e6f45435ea..8042d6c1b706ec2c71d48e94045826f23422b8f6 100644 --- a/.gitattributes +++ b/.gitattributes
@@ -32,3 +32,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+image_0.png filter=lfs diff=lfs merge=lfs -text
+image_1.png filter=lfs diff=lfs merge=lfs -text
+image_2.png filter=lfs diff=lfs merge=lfs -text
+image_3.png filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..79d1facbe641aa7da396a0875b650788e3edef65 --- /dev/null +++ b/.gitignore
@@ -0,0 +1,2 @@
+step_*
+epoch_*
diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..138fb7c8892a2a9debc7498c828a471ea0ba7885 --- /dev/null +++ b/README.md
@@ -0,0 +1,21 @@
+
+---
+license: creativeml-openrail-m
+base_model: runwayml/stable-diffusion-v1-5
+tags:
+- stable-diffusion
+- stable-diffusion-diffusers
+- text-to-image
+- diffusers
+- lora
+inference: true
+---
+
+# LoRA text2image fine-tuning - https://huggingface.co/monroex/pokemon-lora-test
+These are LoRA adaptation weights for runwayml/stable-diffusion-v1-5. The weights were fine-tuned on the lambdalabs/pokemon-blip-captions dataset. You can find some example images below.
+
+![img_0](./image_0.png)
+![img_1](./image_1.png)
+![img_2](./image_2.png)
+![img_3](./image_3.png)
+
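A minimal, hypothetical usage sketch for the LoRA weights described in the model card above. It assumes a diffusers version whose UNet exposes `load_attn_procs` (the loading path used by the `train_text_to_image_lora` example of this era), a CUDA device, and that the final adapter weights were pushed to this repo; the prompt and output file name are invented for illustration.

```python
import torch
from diffusers import StableDiffusionPipeline

# Load the base model the LoRA adapter was trained against.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
# Apply the LoRA attention-processor weights from this repo on top of the base UNet.
pipe.unet.load_attn_procs("monroex/pokemon-lora-test")
pipe = pipe.to("cuda")

image = pipe("A pokemon with blue eyes.", num_inference_steps=30).images[0]
image.save("pokemon.png")
```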
diff --git a/checkpoint-1000/latest b/checkpoint-1000/latest new file mode 100644 index 0000000000000000000000000000000000000000..7b2c8602be034ae63f23b293f8d037fa7afa0c54 --- /dev/null +++ b/checkpoint-1000/latest
@@ -0,0 +1 @@
+pytorch_model
\ No newline at end of file
diff --git a/checkpoint-1000/pytorch_model/mp_rank_00_model_states.pt b/checkpoint-1000/pytorch_model/mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..c2b8e205d6b364b1793c4ff67231ab24d7721e76 --- /dev/null +++ b/checkpoint-1000/pytorch_model/mp_rank_00_model_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:169872da449b31958b28b573994e2025fe26e59b00a678b9942c9a7b506c38af
+size 1658603
diff --git a/checkpoint-1000/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt b/checkpoint-1000/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..7724a44a2f473441256c482146110b9b48def81d --- /dev/null +++ b/checkpoint-1000/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8d05ab7892bb5883e4b0674ef5735a79bb1a3f10dc82fc2ffff79e08f2a773d2
+size 9586591
diff --git a/checkpoint-1000/random_states_0.pkl b/checkpoint-1000/random_states_0.pkl new file mode 100644 index 0000000000000000000000000000000000000000..7b251996efbf07d351cf9ada672c99ebb69025ac --- /dev/null +++ b/checkpoint-1000/random_states_0.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0d88a86e0bada1d71b0585a13d643dd7760b1660a130ea8d211bb81254d56aa6
+size 14631
diff --git a/checkpoint-1000/scheduler.bin b/checkpoint-1000/scheduler.bin new file mode 100644 index 0000000000000000000000000000000000000000..c2fa338af0c2f8be6cf901e23600e82438c0eb40 --- /dev/null +++ b/checkpoint-1000/scheduler.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:06b857a2920eaf1d24ea94339947967e2306d367961eb47bdb2ecd5a3b7942a0
+size 559
diff --git a/checkpoint-1000/zero_to_fp32.py b/checkpoint-1000/zero_to_fp32.py new file mode 100644 index 0000000000000000000000000000000000000000..e5249853c89179a3d5ad212ead8c464d6c979df5 --- /dev/null +++ b/checkpoint-1000/zero_to_fp32.py
@@ -0,0 +1,482 @@
+#!/usr/bin/env python
+
+# This script extracts fp32 consolidated weights from ZeRO stage 2 and 3 DeepSpeed checkpoints. It gets
+# copied into the top-level checkpoint dir, so the user can easily do the conversion at any point in
+# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
+# application.
+#
+# example: python zero_to_fp32.py . pytorch_model.bin
+
+import argparse
+import torch
+import glob
+import math
+import os
+import re
+from collections import OrderedDict
+
+# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
+# DeepSpeed data structures it has to be available in the current python environment.
+from deepspeed.utils import logger
+from deepspeed.checkpoint.constants import (DS_VERSION,
+                                            OPTIMIZER_STATE_DICT,
+                                            SINGLE_PARTITION_OF_FP32_GROUPS,
+                                            FP32_FLAT_GROUPS,
+                                            ZERO_STAGE,
+                                            PARTITION_COUNT,
+                                            PARAM_SHAPES,
+                                            BUFFER_NAMES)
+
+debug = 0
+
+# load to cpu
+device = torch.device('cpu')
+
+
+def atoi(text):
+    return int(text) if text.isdigit() else text
+
+
+def natural_keys(text):
+    '''
+    alist.sort(key=natural_keys) sorts in human order
+    http://nedbatchelder.com/blog/200712/human_sorting.html
+    (See Toothy's implementation in the comments)
+    '''
+    return [atoi(c) for c in re.split(r'(\d+)', text)]
+
+
+def get_model_state_file(checkpoint_dir, zero_stage):
+    if not os.path.isdir(checkpoint_dir):
+        raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
+
+    # there should be only one file
+    if zero_stage == 2:
+        file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
+    elif zero_stage == 3:
+        file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
+
+    if not os.path.exists(file):
+        raise FileNotFoundError(f"can't find model states file at '{file}'")
+
+    return file
+
+
+def get_optim_files(checkpoint_dir):
+    # XXX: need to test that this simple glob rule works for multi-node setup too
+    optim_files = sorted(glob.glob(os.path.join(checkpoint_dir,
+                                                "*_optim_states.pt")),
+                         key=natural_keys)
+
+    if len(optim_files) == 0:
+        raise FileNotFoundError(
+            f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'")
+
+    return optim_files
+
+
+def parse_model_state(file):
+    state_dict = torch.load(file, map_location=device)
+
+    if BUFFER_NAMES not in state_dict:
+        raise ValueError(f"{file} is not a model state checkpoint")
+    buffer_names = state_dict[BUFFER_NAMES]
+    if debug:
+        print("Found buffers:", buffer_names)
+
+    # recover just the buffers while restoring them to fp32 if they were saved in fp16
+    buffers = {
+        k: v.float()
+        for k,
+        v in state_dict["module"].items() if k in buffer_names
+    }
+    param_shapes = state_dict[PARAM_SHAPES]
+
+    ds_version = state_dict.get(DS_VERSION, None)
+
+    return buffers, param_shapes, ds_version
+
+
+def parse_optim_states(files, ds_checkpoint_dir):
+
+    total_files = len(files)
+    state_dicts = []
+    for f in files:
+        state_dicts.append(torch.load(f, map_location=device))
+
+    if ZERO_STAGE not in state_dicts[0][OPTIMIZER_STATE_DICT]:
+        raise ValueError(f"{files[0]} is not a zero checkpoint")
+    zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
+    world_size =
state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] + + # For ZeRO-2 each param group can have different partition_count as data parallelism for expert + # parameters can be different from data parallelism for non-expert parameters. So we can just + # use the max of the partition_count to get the dp world_size. + + if type(world_size) is list: + world_size = max(world_size) + + if world_size != total_files: + raise ValueError( + f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. " + "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." + ) + + # the groups are named differently in each stage + if zero_stage == 2: + fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS + elif zero_stage == 3: + fp32_groups_key = FP32_FLAT_GROUPS + else: + raise ValueError(f"unknown zero stage {zero_stage}") + + if zero_stage == 2: + fp32_flat_groups = [ + state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] + for i in range(len(state_dicts)) + ] + elif zero_stage == 3: + # if there is more than one param group, there will be multiple flattened tensors - one + # flattened tensor per group - for simplicity merge them into a single tensor + # + # XXX: could make the script more memory efficient for when there are multiple groups - it + # will require matching the sub-lists of param_shapes for each param group flattened tensor + + fp32_flat_groups = [ + torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], + 0) for i in range(len(state_dicts)) + ] + + return zero_stage, world_size, fp32_flat_groups + + +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): + """ + Returns fp32 state_dict reconstructed from ds checkpoint + + Args: + - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) + + """ + print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") + + optim_files = get_optim_files(ds_checkpoint_dir) + zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) + print( + f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") + + model_file = get_model_state_file(ds_checkpoint_dir, zero_stage) + buffers, param_shapes, ds_version = parse_model_state(model_file) + print(f'Parsing checkpoint created by deepspeed=={ds_version}') + + if zero_stage == 2: + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers) + elif zero_stage == 3: + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers) + + +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers): + + # Reconstruction protocol: + # + # XXX: document this + + if debug: + for i in range(world_size): + for j in range(len(fp32_flat_groups[0])): + print( + f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") + + # XXX: memory usage doubles here (zero2) + num_param_groups = len(fp32_flat_groups[0]) + merged_single_partition_of_fp32_groups = [] + for i in range(num_param_groups): + merged_partitions = [sd[i] for sd in fp32_flat_groups] + full_single_fp32_vector = torch.cat(merged_partitions, 0) + merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) + avail_numel = sum([ + full_single_fp32_vector.numel() + for full_single_fp32_vector in merged_single_partition_of_fp32_groups + ]) + + if debug: + wanted_params = sum([len(shapes) for shapes in 
param_shapes]) + wanted_numel = sum( + [sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + state_dict = OrderedDict() + + # buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + total_numel = 0 + total_params = 0 + for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups): + offset = 0 + avail_numel = full_single_fp32_vector.numel() + for name, shape in shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + if debug: + print( + f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} " + ) + state_dict[name] = full_single_fp32_vector.narrow( + 0, + offset, + unpartitioned_numel).view(shape) + offset += unpartitioned_numel + + # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and + # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex + # paddings performed in the code it's almost impossible to predict the exact numbers w/o the + # live optimizer object, so we are checking that the numbers are within the right range + align_to = 2 * world_size + + def zero2_align(x): + return align_to * math.ceil(x / align_to) + + if debug: + print(f"original offset={offset}, avail_numel={avail_numel}") + + offset = zero2_align(offset) + avail_numel = zero2_align(avail_numel) + + if debug: + print(f"aligned offset={offset}, avail_numel={avail_numel}") + + # Sanity check + if offset != avail_numel: + raise ValueError( + f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print( + f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements" + ) + + return state_dict + + +def zero3_partitioned_param_info(unpartitioned_numel, world_size): + remainder = unpartitioned_numel % world_size + padding_numel = (world_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + return partitioned_numel, padding_numel + + +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers): + + # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each + # param, re-consolidating each param, while dealing with padding if any + + avail_numel = fp32_flat_groups[0].numel() * world_size + # merge list of dicts, preserving order + param_shapes = {k: v for d in param_shapes for k, v in d.items()} + + if debug: + for i in range(world_size): + print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") + + wanted_params = len(param_shapes) + wanted_numel = sum(shape.numel() for shape in param_shapes.values()) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + state_dict = OrderedDict() + + # buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + offset = 0 + total_numel = 
0 + total_params = 0 + for name, shape in param_shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + # XXX: memory usage doubles here + state_dict[name] = torch.cat( + tuple(fp32_flat_groups[i].narrow(0, + offset, + partitioned_numel) + for i in range(world_size)), + 0).narrow(0, + 0, + unpartitioned_numel).view(shape) + offset += partitioned_numel + + offset *= world_size + + # Sanity check + if offset != avail_numel: + raise ValueError( + f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print( + f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements" + ) + + return state_dict + + +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with + ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example + via a model hub. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` + + Returns: + - pytorch ``state_dict`` + + Note: this approach may not work if your application doesn't have sufficient free CPU memory and + you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with + the checkpoint. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + # do the training and checkpoint saving + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu + model = model.cpu() # move to cpu + model.load_state_dict(state_dict) + # submit to model hub or save the model to share with others + + In this example the ``model`` will no longer be usable in the deepspeed context of the same + application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead. + + """ + if tag is None: + latest_path = os.path.join(checkpoint_dir, 'latest') + if os.path.isfile(latest_path): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + raise ValueError(f"Unable to find 'latest' file at {latest_path}") + + ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) + + if not os.path.isdir(ds_checkpoint_dir): + raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") + + return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir) + + +def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be + loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. 
path/pytorch_model.bin)
+    - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided, will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+    """
+
+    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
+    print(f"Saving fp32 state dict to {output_file}")
+    torch.save(state_dict, output_file)
+
+
+def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
+    """
+    1. Put the provided model on the CPU
+    2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
+    3. Load it into the provided model
+
+    Args:
+    - ``model``: the model object to update
+    - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
+    - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided, will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+
+    Returns:
+    - ``model``: modified model
+
+    Make sure you have plenty of CPU memory available before you call this function. If you don't
+    have enough, use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
+    conveniently placed for you in the checkpoint folder.
+
+    A typical usage might be ::
+
+        from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
+        model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
+        # submit to model hub or save the model to share with others
+
+    Note that once this has been run, the ``model`` will no longer be usable in the deepspeed context
+    of the same application, i.e. you will need to re-initialize the deepspeed engine, since
+    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
+
+    """
+    logger.info(f"Extracting fp32 weights")
+    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
+
+    logger.info(f"Overwriting model with fp32 weights")
+    model = model.cpu()
+    model.load_state_dict(state_dict, strict=False)
+
+    return model
+
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "checkpoint_dir",
+        type=str,
+        help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
+    parser.add_argument(
+        "output_file",
+        type=str,
+        help=
+        "path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)"
+    )
+    parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
+    args = parser.parse_args()
+
+    debug = args.debug
+
+    convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file)
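A sketch of running the conversion against this repo's own layout, assuming a DeepSpeed install that provides the same `deepspeed.utils.zero_to_fp32` module as the script above (the output file name is arbitrary). Since `checkpoint-1000/latest` contains the tag `pytorch_model`, the tag argument can be omitted and is resolved automatically:

```python
# Equivalent to: python checkpoint-1000/zero_to_fp32.py checkpoint-1000 pytorch_model_fp32.bin
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

# 'checkpoint-1000' is the folder containing the 'latest' tag file and the
# 'pytorch_model/' tag-folder with the *_model_states.pt / *_optim_states.pt shards.
convert_zero_checkpoint_to_fp32_state_dict("checkpoint-1000", "pytorch_model_fp32.bin")
```

A toy illustration (not part of the script) of the core reconstruction step the script performs: slicing a consolidated flat fp32 tensor back into named parameters with `narrow()`/`view()`, as in the ZeRO-2 path above. The shapes and values here are invented for the example.

```python
import torch
from collections import OrderedDict

# Pretend these came from PARAM_SHAPES and a merged fp32 flat group.
param_shapes = OrderedDict([("w", torch.Size([2, 3])), ("b", torch.Size([3]))])
flat = torch.arange(9, dtype=torch.float32)  # 6 elements for 'w' + 3 for 'b'

state_dict, offset = OrderedDict(), 0
for name, shape in param_shapes.items():
    numel = shape.numel()
    # Take the next `numel` elements of the flat tensor and restore the shape.
    state_dict[name] = flat.narrow(0, offset, numel).view(shape)
    offset += numel

assert state_dict["w"].shape == (2, 3) and state_dict["b"].shape == (3,)
```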
diff --git a/checkpoint-1500/latest b/checkpoint-1500/latest new file mode 100644 index 0000000000000000000000000000000000000000..7b2c8602be034ae63f23b293f8d037fa7afa0c54 --- /dev/null +++ b/checkpoint-1500/latest
@@ -0,0 +1 @@
+pytorch_model
\ No newline at end of file
diff --git a/checkpoint-1500/pytorch_model/mp_rank_00_model_states.pt b/checkpoint-1500/pytorch_model/mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..4bf05964d037244dc62a0e188d332b950972e029 --- /dev/null +++ b/checkpoint-1500/pytorch_model/mp_rank_00_model_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5cb7794dd4d96beae4e8f680188df4d41715ff8c26d4d6612af1536712c4a9fa
+size 1658603
diff --git a/checkpoint-1500/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt b/checkpoint-1500/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..3068b7a440478105c63b010ef12698e51a91858e --- /dev/null +++ b/checkpoint-1500/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9f883f58fef49ed13d62179896d73153b0d5eb670821292b70cc5650a7e8dbc3
+size 9586591
diff --git a/checkpoint-1500/random_states_0.pkl b/checkpoint-1500/random_states_0.pkl new file mode 100644 index 0000000000000000000000000000000000000000..106797b8e1ead61a272eea95b2605c94fa6c8651 --- /dev/null +++ b/checkpoint-1500/random_states_0.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2a464abfcec8f339c480f0d0b00298e843509f3b08dc0726079885ddd038c192
+size 14631
diff --git a/checkpoint-1500/scheduler.bin b/checkpoint-1500/scheduler.bin new file mode 100644 index 0000000000000000000000000000000000000000..11a76d99d4e952fa711d92fda02e88dfa48d8eea --- /dev/null +++ b/checkpoint-1500/scheduler.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8bde3ddbf48d5579eff3e63e912d80d76402de061b33f5e67a61dfb8be09c0e3
+size 559
diff --git a/checkpoint-1500/zero_to_fp32.py b/checkpoint-1500/zero_to_fp32.py new file mode 100644 index 0000000000000000000000000000000000000000..e5249853c89179a3d5ad212ead8c464d6c979df5 --- /dev/null +++ b/checkpoint-1500/zero_to_fp32.py
@@ -0,0 +1,482 @@
+#!/usr/bin/env python
+
+# This script extracts fp32 consolidated weights from ZeRO stage 2 and 3 DeepSpeed checkpoints. It gets
+# copied into the top-level checkpoint dir, so the user can easily do the conversion at any point in
+# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
+# application.
+#
+# example: python zero_to_fp32.py . pytorch_model.bin
+
+import argparse
+import torch
+import glob
+import math
+import os
+import re
+from collections import OrderedDict
+
+# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
+# DeepSpeed data structures it has to be available in the current python environment.
+from deepspeed.utils import logger +from deepspeed.checkpoint.constants import (DS_VERSION, + OPTIMIZER_STATE_DICT, + SINGLE_PARTITION_OF_FP32_GROUPS, + FP32_FLAT_GROUPS, + ZERO_STAGE, + PARTITION_COUNT, + PARAM_SHAPES, + BUFFER_NAMES) + +debug = 0 + +# load to cpu +device = torch.device('cpu') + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + ''' + return [atoi(c) for c in re.split(r'(\d+)', text)] + + +def get_model_state_file(checkpoint_dir, zero_stage): + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + + # there should be only one file + if zero_stage == 2: + file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") + elif zero_stage == 3: + file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") + + if not os.path.exists(file): + raise FileNotFoundError(f"can't find model states file at '{file}'") + + return file + + +def get_optim_files(checkpoint_dir): + # XXX: need to test that this simple glob rule works for multi-node setup too + optim_files = sorted(glob.glob(os.path.join(checkpoint_dir, + "*_optim_states.pt")), + key=natural_keys) + + if len(optim_files) == 0: + raise FileNotFoundError( + f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'") + + return optim_files + + +def parse_model_state(file): + state_dict = torch.load(file, map_location=device) + + if BUFFER_NAMES not in state_dict: + raise ValueError(f"{file} is not a model state checkpoint") + buffer_names = state_dict[BUFFER_NAMES] + if debug: + print("Found buffers:", buffer_names) + + # recover just the buffers while restoring them to fp32 if they were saved in fp16 + buffers = { + k: v.float() + for k, + v in state_dict["module"].items() if k in buffer_names + } + param_shapes = state_dict[PARAM_SHAPES] + + ds_version = state_dict.get(DS_VERSION, None) + + return buffers, param_shapes, ds_version + + +def parse_optim_states(files, ds_checkpoint_dir): + + total_files = len(files) + state_dicts = [] + for f in files: + state_dicts.append(torch.load(f, map_location=device)) + + if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: + raise ValueError(f"{files[0]} is not a zero checkpoint") + zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] + world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] + + # For ZeRO-2 each param group can have different partition_count as data parallelism for expert + # parameters can be different from data parallelism for non-expert parameters. So we can just + # use the max of the partition_count to get the dp world_size. + + if type(world_size) is list: + world_size = max(world_size) + + if world_size != total_files: + raise ValueError( + f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. " + "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." 
+ ) + + # the groups are named differently in each stage + if zero_stage == 2: + fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS + elif zero_stage == 3: + fp32_groups_key = FP32_FLAT_GROUPS + else: + raise ValueError(f"unknown zero stage {zero_stage}") + + if zero_stage == 2: + fp32_flat_groups = [ + state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] + for i in range(len(state_dicts)) + ] + elif zero_stage == 3: + # if there is more than one param group, there will be multiple flattened tensors - one + # flattened tensor per group - for simplicity merge them into a single tensor + # + # XXX: could make the script more memory efficient for when there are multiple groups - it + # will require matching the sub-lists of param_shapes for each param group flattened tensor + + fp32_flat_groups = [ + torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], + 0) for i in range(len(state_dicts)) + ] + + return zero_stage, world_size, fp32_flat_groups + + +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): + """ + Returns fp32 state_dict reconstructed from ds checkpoint + + Args: + - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) + + """ + print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") + + optim_files = get_optim_files(ds_checkpoint_dir) + zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) + print( + f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") + + model_file = get_model_state_file(ds_checkpoint_dir, zero_stage) + buffers, param_shapes, ds_version = parse_model_state(model_file) + print(f'Parsing checkpoint created by deepspeed=={ds_version}') + + if zero_stage == 2: + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers) + elif zero_stage == 3: + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers) + + +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers): + + # Reconstruction protocol: + # + # XXX: document this + + if debug: + for i in range(world_size): + for j in range(len(fp32_flat_groups[0])): + print( + f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") + + # XXX: memory usage doubles here (zero2) + num_param_groups = len(fp32_flat_groups[0]) + merged_single_partition_of_fp32_groups = [] + for i in range(num_param_groups): + merged_partitions = [sd[i] for sd in fp32_flat_groups] + full_single_fp32_vector = torch.cat(merged_partitions, 0) + merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) + avail_numel = sum([ + full_single_fp32_vector.numel() + for full_single_fp32_vector in merged_single_partition_of_fp32_groups + ]) + + if debug: + wanted_params = sum([len(shapes) for shapes in param_shapes]) + wanted_numel = sum( + [sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + state_dict = OrderedDict() + + # buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + total_numel = 0 + total_params = 0 + for shapes, full_single_fp32_vector in 
zip(param_shapes, merged_single_partition_of_fp32_groups): + offset = 0 + avail_numel = full_single_fp32_vector.numel() + for name, shape in shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + if debug: + print( + f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} " + ) + state_dict[name] = full_single_fp32_vector.narrow( + 0, + offset, + unpartitioned_numel).view(shape) + offset += unpartitioned_numel + + # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and + # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex + # paddings performed in the code it's almost impossible to predict the exact numbers w/o the + # live optimizer object, so we are checking that the numbers are within the right range + align_to = 2 * world_size + + def zero2_align(x): + return align_to * math.ceil(x / align_to) + + if debug: + print(f"original offset={offset}, avail_numel={avail_numel}") + + offset = zero2_align(offset) + avail_numel = zero2_align(avail_numel) + + if debug: + print(f"aligned offset={offset}, avail_numel={avail_numel}") + + # Sanity check + if offset != avail_numel: + raise ValueError( + f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print( + f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements" + ) + + return state_dict + + +def zero3_partitioned_param_info(unpartitioned_numel, world_size): + remainder = unpartitioned_numel % world_size + padding_numel = (world_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + return partitioned_numel, padding_numel + + +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers): + + # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each + # param, re-consolidating each param, while dealing with padding if any + + avail_numel = fp32_flat_groups[0].numel() * world_size + # merge list of dicts, preserving order + param_shapes = {k: v for d in param_shapes for k, v in d.items()} + + if debug: + for i in range(world_size): + print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") + + wanted_params = len(param_shapes) + wanted_numel = sum(shape.numel() for shape in param_shapes.values()) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + state_dict = OrderedDict() + + # buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + offset = 0 + total_numel = 0 + total_params = 0 + for name, shape in param_shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + # XXX: memory usage doubles here + state_dict[name] = torch.cat( + tuple(fp32_flat_groups[i].narrow(0, + offset, + partitioned_numel) + for i in range(world_size)), + 0).narrow(0, + 0, + 
unpartitioned_numel).view(shape) + offset += partitioned_numel + + offset *= world_size + + # Sanity check + if offset != avail_numel: + raise ValueError( + f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print( + f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements" + ) + + return state_dict + + +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with + ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example + via a model hub. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` + + Returns: + - pytorch ``state_dict`` + + Note: this approach may not work if your application doesn't have sufficient free CPU memory and + you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with + the checkpoint. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + # do the training and checkpoint saving + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu + model = model.cpu() # move to cpu + model.load_state_dict(state_dict) + # submit to model hub or save the model to share with others + + In this example the ``model`` will no longer be usable in the deepspeed context of the same + application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead. + + """ + if tag is None: + latest_path = os.path.join(checkpoint_dir, 'latest') + if os.path.isfile(latest_path): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + raise ValueError(f"Unable to find 'latest' file at {latest_path}") + + ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) + + if not os.path.isdir(ds_checkpoint_dir): + raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") + + return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir) + + +def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be + loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + """ + + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + print(f"Saving fp32 state dict to {output_file}") + torch.save(state_dict, output_file) + + +def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): + """ + 1. Put the provided model to cpu + 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` + 3. 
Load it into the provided model
+
+    Args:
+    - ``model``: the model object to update
+    - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
+    - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided, will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+
+    Returns:
+    - ``model``: modified model
+
+    Make sure you have plenty of CPU memory available before you call this function. If you don't
+    have enough, use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
+    conveniently placed for you in the checkpoint folder.
+
+    A typical usage might be ::
+
+        from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
+        model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
+        # submit to model hub or save the model to share with others
+
+    Note that once this has been run, the ``model`` will no longer be usable in the deepspeed context
+    of the same application, i.e. you will need to re-initialize the deepspeed engine, since
+    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
+
+    """
+    logger.info(f"Extracting fp32 weights")
+    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
+
+    logger.info(f"Overwriting model with fp32 weights")
+    model = model.cpu()
+    model.load_state_dict(state_dict, strict=False)
+
+    return model
+
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "checkpoint_dir",
+        type=str,
+        help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
+    parser.add_argument(
+        "output_file",
+        type=str,
+        help=
+        "path to the pytorch fp32 state_dict output file (e.g.
path/checkpoint-12/pytorch_model.bin)" + ) + parser.add_argument("-d", "--debug", action='store_true', help="enable debug") + args = parser.parse_args() + + debug = args.debug + + convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file) diff --git a/checkpoint-2000/latest b/checkpoint-2000/latest new file mode 100644 index 0000000000000000000000000000000000000000..7b2c8602be034ae63f23b293f8d037fa7afa0c54 --- /dev/null +++ b/checkpoint-2000/latest @@ -0,0 +1 @@ +pytorch_model \ No newline at end of file diff --git a/checkpoint-2000/pytorch_model/mp_rank_00_model_states.pt b/checkpoint-2000/pytorch_model/mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..2b6a8cc40239dab3d50f08bbcb4acb92b13466d0 --- /dev/null +++ b/checkpoint-2000/pytorch_model/mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7b3b0642b2823e050cda858d93dde428935a5718a75fb05c2d1a44e967f255be +size 1658603 diff --git a/checkpoint-2000/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt b/checkpoint-2000/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..cdf28d774d440cb6f58c9664ea38f8c1534f5c31 --- /dev/null +++ b/checkpoint-2000/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cb991cf4563b45ca63cc5cd171ca61fba868175813f186ffbc3b76b8a3d65c74 +size 9586591 diff --git a/checkpoint-2000/random_states_0.pkl b/checkpoint-2000/random_states_0.pkl new file mode 100644 index 0000000000000000000000000000000000000000..46260dad27d5e584233095609fbaff672925da82 --- /dev/null +++ b/checkpoint-2000/random_states_0.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8899469d0b4c825f611ea7567cdaedf26de7dc0b2597ea605010c3f416667e16 +size 14631 diff --git a/checkpoint-2000/scheduler.bin b/checkpoint-2000/scheduler.bin new file mode 100644 index 0000000000000000000000000000000000000000..3224df2cde6270efa7a6607c95152517e1ad5eb2 --- /dev/null +++ b/checkpoint-2000/scheduler.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:52ff06f0d1816bad2a0f3a8d09f4c5cb3fab059ddb8a606f334a5f5f83a64702 +size 559 diff --git a/checkpoint-2000/zero_to_fp32.py b/checkpoint-2000/zero_to_fp32.py new file mode 100644 index 0000000000000000000000000000000000000000..e5249853c89179a3d5ad212ead8c464d6c979df5 --- /dev/null +++ b/checkpoint-2000/zero_to_fp32.py @@ -0,0 +1,482 @@ +#!/usr/bin/env python + +# This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets +# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in +# the future. Once extracted, the weights don't require DeepSpeed and can be used in any +# application. +# +# example: python zero_to_fp32.py . pytorch_model.bin + +import argparse +import torch +import glob +import math +import os +import re +from collections import OrderedDict + +# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with +# DeepSpeed data structures it has to be available in the current python environment. 
+from deepspeed.utils import logger +from deepspeed.checkpoint.constants import (DS_VERSION, + OPTIMIZER_STATE_DICT, + SINGLE_PARTITION_OF_FP32_GROUPS, + FP32_FLAT_GROUPS, + ZERO_STAGE, + PARTITION_COUNT, + PARAM_SHAPES, + BUFFER_NAMES) + +debug = 0 + +# load to cpu +device = torch.device('cpu') + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + ''' + return [atoi(c) for c in re.split(r'(\d+)', text)] + + +def get_model_state_file(checkpoint_dir, zero_stage): + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + + # there should be only one file + if zero_stage == 2: + file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") + elif zero_stage == 3: + file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") + + if not os.path.exists(file): + raise FileNotFoundError(f"can't find model states file at '{file}'") + + return file + + +def get_optim_files(checkpoint_dir): + # XXX: need to test that this simple glob rule works for multi-node setup too + optim_files = sorted(glob.glob(os.path.join(checkpoint_dir, + "*_optim_states.pt")), + key=natural_keys) + + if len(optim_files) == 0: + raise FileNotFoundError( + f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'") + + return optim_files + + +def parse_model_state(file): + state_dict = torch.load(file, map_location=device) + + if BUFFER_NAMES not in state_dict: + raise ValueError(f"{file} is not a model state checkpoint") + buffer_names = state_dict[BUFFER_NAMES] + if debug: + print("Found buffers:", buffer_names) + + # recover just the buffers while restoring them to fp32 if they were saved in fp16 + buffers = { + k: v.float() + for k, + v in state_dict["module"].items() if k in buffer_names + } + param_shapes = state_dict[PARAM_SHAPES] + + ds_version = state_dict.get(DS_VERSION, None) + + return buffers, param_shapes, ds_version + + +def parse_optim_states(files, ds_checkpoint_dir): + + total_files = len(files) + state_dicts = [] + for f in files: + state_dicts.append(torch.load(f, map_location=device)) + + if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: + raise ValueError(f"{files[0]} is not a zero checkpoint") + zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] + world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] + + # For ZeRO-2 each param group can have different partition_count as data parallelism for expert + # parameters can be different from data parallelism for non-expert parameters. So we can just + # use the max of the partition_count to get the dp world_size. + + if type(world_size) is list: + world_size = max(world_size) + + if world_size != total_files: + raise ValueError( + f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. " + "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." 
+ ) + + # the groups are named differently in each stage + if zero_stage == 2: + fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS + elif zero_stage == 3: + fp32_groups_key = FP32_FLAT_GROUPS + else: + raise ValueError(f"unknown zero stage {zero_stage}") + + if zero_stage == 2: + fp32_flat_groups = [ + state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] + for i in range(len(state_dicts)) + ] + elif zero_stage == 3: + # if there is more than one param group, there will be multiple flattened tensors - one + # flattened tensor per group - for simplicity merge them into a single tensor + # + # XXX: could make the script more memory efficient for when there are multiple groups - it + # will require matching the sub-lists of param_shapes for each param group flattened tensor + + fp32_flat_groups = [ + torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], + 0) for i in range(len(state_dicts)) + ] + + return zero_stage, world_size, fp32_flat_groups + + +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): + """ + Returns fp32 state_dict reconstructed from ds checkpoint + + Args: + - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) + + """ + print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") + + optim_files = get_optim_files(ds_checkpoint_dir) + zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) + print( + f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") + + model_file = get_model_state_file(ds_checkpoint_dir, zero_stage) + buffers, param_shapes, ds_version = parse_model_state(model_file) + print(f'Parsing checkpoint created by deepspeed=={ds_version}') + + if zero_stage == 2: + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers) + elif zero_stage == 3: + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers) + + +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers): + + # Reconstruction protocol: + # + # XXX: document this + + if debug: + for i in range(world_size): + for j in range(len(fp32_flat_groups[0])): + print( + f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") + + # XXX: memory usage doubles here (zero2) + num_param_groups = len(fp32_flat_groups[0]) + merged_single_partition_of_fp32_groups = [] + for i in range(num_param_groups): + merged_partitions = [sd[i] for sd in fp32_flat_groups] + full_single_fp32_vector = torch.cat(merged_partitions, 0) + merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) + avail_numel = sum([ + full_single_fp32_vector.numel() + for full_single_fp32_vector in merged_single_partition_of_fp32_groups + ]) + + if debug: + wanted_params = sum([len(shapes) for shapes in param_shapes]) + wanted_numel = sum( + [sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + state_dict = OrderedDict() + + # buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + total_numel = 0 + total_params = 0 + for shapes, full_single_fp32_vector in 
zip(param_shapes, merged_single_partition_of_fp32_groups): + offset = 0 + avail_numel = full_single_fp32_vector.numel() + for name, shape in shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + if debug: + print( + f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} " + ) + state_dict[name] = full_single_fp32_vector.narrow( + 0, + offset, + unpartitioned_numel).view(shape) + offset += unpartitioned_numel + + # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and + # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex + # paddings performed in the code it's almost impossible to predict the exact numbers w/o the + # live optimizer object, so we are checking that the numbers are within the right range + align_to = 2 * world_size + + def zero2_align(x): + return align_to * math.ceil(x / align_to) + + if debug: + print(f"original offset={offset}, avail_numel={avail_numel}") + + offset = zero2_align(offset) + avail_numel = zero2_align(avail_numel) + + if debug: + print(f"aligned offset={offset}, avail_numel={avail_numel}") + + # Sanity check + if offset != avail_numel: + raise ValueError( + f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print( + f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements" + ) + + return state_dict + + +def zero3_partitioned_param_info(unpartitioned_numel, world_size): + remainder = unpartitioned_numel % world_size + padding_numel = (world_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + return partitioned_numel, padding_numel + + +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers): + + # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each + # param, re-consolidating each param, while dealing with padding if any + + avail_numel = fp32_flat_groups[0].numel() * world_size + # merge list of dicts, preserving order + param_shapes = {k: v for d in param_shapes for k, v in d.items()} + + if debug: + for i in range(world_size): + print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") + + wanted_params = len(param_shapes) + wanted_numel = sum(shape.numel() for shape in param_shapes.values()) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + state_dict = OrderedDict() + + # buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + offset = 0 + total_numel = 0 + total_params = 0 + for name, shape in param_shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + # XXX: memory usage doubles here + state_dict[name] = torch.cat( + tuple(fp32_flat_groups[i].narrow(0, + offset, + partitioned_numel) + for i in range(world_size)), + 0).narrow(0, + 0, + 
unpartitioned_numel).view(shape) + offset += partitioned_numel + + offset *= world_size + + # Sanity check + if offset != avail_numel: + raise ValueError( + f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print( + f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements" + ) + + return state_dict + + +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with + ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example + via a model hub. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` + + Returns: + - pytorch ``state_dict`` + + Note: this approach may not work if your application doesn't have sufficient free CPU memory and + you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with + the checkpoint. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + # do the training and checkpoint saving + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu + model = model.cpu() # move to cpu + model.load_state_dict(state_dict) + # submit to model hub or save the model to share with others + + In this example the ``model`` will no longer be usable in the deepspeed context of the same + application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead. + + """ + if tag is None: + latest_path = os.path.join(checkpoint_dir, 'latest') + if os.path.isfile(latest_path): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + raise ValueError(f"Unable to find 'latest' file at {latest_path}") + + ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) + + if not os.path.isdir(ds_checkpoint_dir): + raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") + + return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir) + + +def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be + loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + """ + + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + print(f"Saving fp32 state dict to {output_file}") + torch.save(state_dict, output_file) + + +def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): + """ + 1. Put the provided model to cpu + 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` + 3. 
Load it into the provided model
+
+    Args:
+    - ``model``: the model object to update
+    - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
+    - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided, will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+
+    Returns:
+    - ``model``: modified model
+
+    Make sure you have plenty of CPU memory available before you call this function. If you don't
+    have enough, use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
+    conveniently placed for you in the checkpoint folder.
+
+    A typical usage might be ::
+
+        from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
+        model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
+        # submit to model hub or save the model to share with others
+
+    Note that once this has been run, the ``model`` will no longer be usable in the deepspeed context
+    of the same application, i.e. you will need to re-initialize the deepspeed engine, since
+    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
+
+    """
+    logger.info(f"Extracting fp32 weights")
+    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
+
+    logger.info(f"Overwriting model with fp32 weights")
+    model = model.cpu()
+    model.load_state_dict(state_dict, strict=False)
+
+    return model
+
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "checkpoint_dir",
+        type=str,
+        help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
+    parser.add_argument(
+        "output_file",
+        type=str,
+        help=
+        "path to the pytorch fp32 state_dict output file (e.g.
path/checkpoint-12/pytorch_model.bin)" + ) + parser.add_argument("-d", "--debug", action='store_true', help="enable debug") + args = parser.parse_args() + + debug = args.debug + + convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file) diff --git a/checkpoint-2500/latest b/checkpoint-2500/latest new file mode 100644 index 0000000000000000000000000000000000000000..7b2c8602be034ae63f23b293f8d037fa7afa0c54 --- /dev/null +++ b/checkpoint-2500/latest @@ -0,0 +1 @@ +pytorch_model \ No newline at end of file diff --git a/checkpoint-2500/pytorch_model/mp_rank_00_model_states.pt b/checkpoint-2500/pytorch_model/mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..4c18255ec99a4609831ee12c2a314d28f2cca559 --- /dev/null +++ b/checkpoint-2500/pytorch_model/mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4016cbfcfa946bcbec565d89a6bd2b26636d65ef3886437301be0ec6dbd2f1ef +size 1658603 diff --git a/checkpoint-2500/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt b/checkpoint-2500/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..c7dc9ef02db8c5572ae6f0dc9d44cdd1bb77f5c2 --- /dev/null +++ b/checkpoint-2500/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:452469d6f04f02345bf22cc37fb2caed563c3a8c638d8eee0316e3e983a629fd +size 9586591 diff --git a/checkpoint-2500/random_states_0.pkl b/checkpoint-2500/random_states_0.pkl new file mode 100644 index 0000000000000000000000000000000000000000..472118dcc0b9459a8b0d370a498598549fd865d9 --- /dev/null +++ b/checkpoint-2500/random_states_0.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c2d5cd5da237cc96c33e0f51c8ce23bd00573b7f6c49d49c379c43fd53966d9b +size 14631 diff --git a/checkpoint-2500/scheduler.bin b/checkpoint-2500/scheduler.bin new file mode 100644 index 0000000000000000000000000000000000000000..2dc4ef41a5342fbdf24ea3fba16576916174aed7 --- /dev/null +++ b/checkpoint-2500/scheduler.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b8419bbafa68cf47bd7cf9b8add092bd812c64073af54d05d40d5e656288d8e3 +size 559 diff --git a/checkpoint-2500/zero_to_fp32.py b/checkpoint-2500/zero_to_fp32.py new file mode 100644 index 0000000000000000000000000000000000000000..e5249853c89179a3d5ad212ead8c464d6c979df5 --- /dev/null +++ b/checkpoint-2500/zero_to_fp32.py @@ -0,0 +1,482 @@ +#!/usr/bin/env python + +# This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets +# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in +# the future. Once extracted, the weights don't require DeepSpeed and can be used in any +# application. +# +# example: python zero_to_fp32.py . pytorch_model.bin + +import argparse +import torch +import glob +import math +import os +import re +from collections import OrderedDict + +# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with +# DeepSpeed data structures it has to be available in the current python environment. 
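+
+# NOTE (illustrative comment, not part of the original DeepSpeed script; paths
+# are assumptions based on this repository's layout): every checkpoint folder
+# here ships a `latest` file whose content is the tag `pytorch_model`, so the
+# tag folder for this copy is checkpoint-2500/pytorch_model/. Run from the repo
+# root, a typical invocation would be:
+#
+#   python checkpoint-2500/zero_to_fp32.py checkpoint-2500 pytorch_model.bin
+#
+# The single zero_pp_rank_0_mp_rank_00_optim_states.pt shard per checkpoint
+# suggests a data-parallel world size of 1, so no cross-rank stitching is
+# needed when the partitions are reassembled below.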
+from deepspeed.utils import logger +from deepspeed.checkpoint.constants import (DS_VERSION, + OPTIMIZER_STATE_DICT, + SINGLE_PARTITION_OF_FP32_GROUPS, + FP32_FLAT_GROUPS, + ZERO_STAGE, + PARTITION_COUNT, + PARAM_SHAPES, + BUFFER_NAMES) + +debug = 0 + +# load to cpu +device = torch.device('cpu') + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + ''' + return [atoi(c) for c in re.split(r'(\d+)', text)] + + +def get_model_state_file(checkpoint_dir, zero_stage): + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + + # there should be only one file + if zero_stage == 2: + file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") + elif zero_stage == 3: + file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") + + if not os.path.exists(file): + raise FileNotFoundError(f"can't find model states file at '{file}'") + + return file + + +def get_optim_files(checkpoint_dir): + # XXX: need to test that this simple glob rule works for multi-node setup too + optim_files = sorted(glob.glob(os.path.join(checkpoint_dir, + "*_optim_states.pt")), + key=natural_keys) + + if len(optim_files) == 0: + raise FileNotFoundError( + f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'") + + return optim_files + + +def parse_model_state(file): + state_dict = torch.load(file, map_location=device) + + if BUFFER_NAMES not in state_dict: + raise ValueError(f"{file} is not a model state checkpoint") + buffer_names = state_dict[BUFFER_NAMES] + if debug: + print("Found buffers:", buffer_names) + + # recover just the buffers while restoring them to fp32 if they were saved in fp16 + buffers = { + k: v.float() + for k, + v in state_dict["module"].items() if k in buffer_names + } + param_shapes = state_dict[PARAM_SHAPES] + + ds_version = state_dict.get(DS_VERSION, None) + + return buffers, param_shapes, ds_version + + +def parse_optim_states(files, ds_checkpoint_dir): + + total_files = len(files) + state_dicts = [] + for f in files: + state_dicts.append(torch.load(f, map_location=device)) + + if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: + raise ValueError(f"{files[0]} is not a zero checkpoint") + zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] + world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] + + # For ZeRO-2 each param group can have different partition_count as data parallelism for expert + # parameters can be different from data parallelism for non-expert parameters. So we can just + # use the max of the partition_count to get the dp world_size. + + if type(world_size) is list: + world_size = max(world_size) + + if world_size != total_files: + raise ValueError( + f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. " + "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." 
+ ) + + # the groups are named differently in each stage + if zero_stage == 2: + fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS + elif zero_stage == 3: + fp32_groups_key = FP32_FLAT_GROUPS + else: + raise ValueError(f"unknown zero stage {zero_stage}") + + if zero_stage == 2: + fp32_flat_groups = [ + state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] + for i in range(len(state_dicts)) + ] + elif zero_stage == 3: + # if there is more than one param group, there will be multiple flattened tensors - one + # flattened tensor per group - for simplicity merge them into a single tensor + # + # XXX: could make the script more memory efficient for when there are multiple groups - it + # will require matching the sub-lists of param_shapes for each param group flattened tensor + + fp32_flat_groups = [ + torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], + 0) for i in range(len(state_dicts)) + ] + + return zero_stage, world_size, fp32_flat_groups + + +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): + """ + Returns fp32 state_dict reconstructed from ds checkpoint + + Args: + - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) + + """ + print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") + + optim_files = get_optim_files(ds_checkpoint_dir) + zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) + print( + f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") + + model_file = get_model_state_file(ds_checkpoint_dir, zero_stage) + buffers, param_shapes, ds_version = parse_model_state(model_file) + print(f'Parsing checkpoint created by deepspeed=={ds_version}') + + if zero_stage == 2: + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers) + elif zero_stage == 3: + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers) + + +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers): + + # Reconstruction protocol: + # + # XXX: document this + + if debug: + for i in range(world_size): + for j in range(len(fp32_flat_groups[0])): + print( + f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") + + # XXX: memory usage doubles here (zero2) + num_param_groups = len(fp32_flat_groups[0]) + merged_single_partition_of_fp32_groups = [] + for i in range(num_param_groups): + merged_partitions = [sd[i] for sd in fp32_flat_groups] + full_single_fp32_vector = torch.cat(merged_partitions, 0) + merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) + avail_numel = sum([ + full_single_fp32_vector.numel() + for full_single_fp32_vector in merged_single_partition_of_fp32_groups + ]) + + if debug: + wanted_params = sum([len(shapes) for shapes in param_shapes]) + wanted_numel = sum( + [sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + state_dict = OrderedDict() + + # buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + total_numel = 0 + total_params = 0 + for shapes, full_single_fp32_vector in 
zip(param_shapes, merged_single_partition_of_fp32_groups): + offset = 0 + avail_numel = full_single_fp32_vector.numel() + for name, shape in shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + if debug: + print( + f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} " + ) + state_dict[name] = full_single_fp32_vector.narrow( + 0, + offset, + unpartitioned_numel).view(shape) + offset += unpartitioned_numel + + # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and + # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex + # paddings performed in the code it's almost impossible to predict the exact numbers w/o the + # live optimizer object, so we are checking that the numbers are within the right range + align_to = 2 * world_size + + def zero2_align(x): + return align_to * math.ceil(x / align_to) + + if debug: + print(f"original offset={offset}, avail_numel={avail_numel}") + + offset = zero2_align(offset) + avail_numel = zero2_align(avail_numel) + + if debug: + print(f"aligned offset={offset}, avail_numel={avail_numel}") + + # Sanity check + if offset != avail_numel: + raise ValueError( + f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print( + f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements" + ) + + return state_dict + + +def zero3_partitioned_param_info(unpartitioned_numel, world_size): + remainder = unpartitioned_numel % world_size + padding_numel = (world_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + return partitioned_numel, padding_numel + + +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers): + + # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each + # param, re-consolidating each param, while dealing with padding if any + + avail_numel = fp32_flat_groups[0].numel() * world_size + # merge list of dicts, preserving order + param_shapes = {k: v for d in param_shapes for k, v in d.items()} + + if debug: + for i in range(world_size): + print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") + + wanted_params = len(param_shapes) + wanted_numel = sum(shape.numel() for shape in param_shapes.values()) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + state_dict = OrderedDict() + + # buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + offset = 0 + total_numel = 0 + total_params = 0 + for name, shape in param_shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + # XXX: memory usage doubles here + state_dict[name] = torch.cat( + tuple(fp32_flat_groups[i].narrow(0, + offset, + partitioned_numel) + for i in range(world_size)), + 0).narrow(0, + 0, + 
unpartitioned_numel).view(shape) + offset += partitioned_numel + + offset *= world_size + + # Sanity check + if offset != avail_numel: + raise ValueError( + f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print( + f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements" + ) + + return state_dict + + +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with + ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example + via a model hub. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` + + Returns: + - pytorch ``state_dict`` + + Note: this approach may not work if your application doesn't have sufficient free CPU memory and + you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with + the checkpoint. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + # do the training and checkpoint saving + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu + model = model.cpu() # move to cpu + model.load_state_dict(state_dict) + # submit to model hub or save the model to share with others + + In this example the ``model`` will no longer be usable in the deepspeed context of the same + application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead. + + """ + if tag is None: + latest_path = os.path.join(checkpoint_dir, 'latest') + if os.path.isfile(latest_path): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + raise ValueError(f"Unable to find 'latest' file at {latest_path}") + + ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) + + if not os.path.isdir(ds_checkpoint_dir): + raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") + + return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir) + + +def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be + loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + """ + + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + print(f"Saving fp32 state dict to {output_file}") + torch.save(state_dict, output_file) + + +def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): + """ + 1. Put the provided model to cpu + 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` + 3. 
Load it into the provided model + + Args: + - ``model``: the model object to update + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + + Returns: + - ``model``: modified model + + Make sure you have plenty of CPU memory available before you call this function. If you don't + have enough, use the ``zero_to_fp32.py`` utility to do the conversion. You will find it + conveniently placed for you in the checkpoint folder. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint + model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir) + # submit to model hub or save the model to share with others + + Note that once this was run, the ``model`` will no longer be usable in the deepspeed context + of the same application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + """ + logger.info(f"Extracting fp32 weights") + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + + logger.info(f"Overwriting model with fp32 weights") + model = model.cpu() + model.load_state_dict(state_dict, strict=False) + + return model + + if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument( + "checkpoint_dir", + type=str, + help="path to the desired checkpoint folder, e.g., path/checkpoint-12") + parser.add_argument( + "output_file", + type=str, + help= + "path to the pytorch fp32 state_dict output file (e.g.
path/checkpoint-12/pytorch_model.bin)" + ) + parser.add_argument("-d", "--debug", action='store_true', help="enable debug") + args = parser.parse_args() + + debug = args.debug + + convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file) diff --git a/checkpoint-3000/latest b/checkpoint-3000/latest new file mode 100644 index 0000000000000000000000000000000000000000..7b2c8602be034ae63f23b293f8d037fa7afa0c54 --- /dev/null +++ b/checkpoint-3000/latest @@ -0,0 +1 @@ +pytorch_model \ No newline at end of file diff --git a/checkpoint-3000/pytorch_model/mp_rank_00_model_states.pt b/checkpoint-3000/pytorch_model/mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..9306e0761356304680855c2ef791a004f92c6cba --- /dev/null +++ b/checkpoint-3000/pytorch_model/mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:63e08adf3346a6cd26722df9173da2ad9ab8513eecfca42a62900f88bfd72e1e +size 1658603 diff --git a/checkpoint-3000/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt b/checkpoint-3000/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..419f50e29381f20398352f15a77a38fe9634fda2 --- /dev/null +++ b/checkpoint-3000/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:92ecacc3f2c6e1268547436692d8701de60af8c0a8995d51bd801b638dfb4cbd +size 9586591 diff --git a/checkpoint-3000/random_states_0.pkl b/checkpoint-3000/random_states_0.pkl new file mode 100644 index 0000000000000000000000000000000000000000..a43f6df0173620f2b121c954cc1fc8259fc2926e --- /dev/null +++ b/checkpoint-3000/random_states_0.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:323fd334a02ff024e76fc1fd6ec956772e1aa36c8bc46b39641ea02a81a1cf72 +size 14631 diff --git a/checkpoint-3000/scheduler.bin b/checkpoint-3000/scheduler.bin new file mode 100644 index 0000000000000000000000000000000000000000..9573016397ce5a3f1874d1f3923a8d41e60f81de --- /dev/null +++ b/checkpoint-3000/scheduler.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0a108147e06a70650bad481dc674220c751b0b548ef761db450a9714bcc04543 +size 559 diff --git a/checkpoint-3000/zero_to_fp32.py b/checkpoint-3000/zero_to_fp32.py new file mode 100644 index 0000000000000000000000000000000000000000..e5249853c89179a3d5ad212ead8c464d6c979df5 --- /dev/null +++ b/checkpoint-3000/zero_to_fp32.py @@ -0,0 +1,482 @@ +#!/usr/bin/env python + +# This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets +# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in +# the future. Once extracted, the weights don't require DeepSpeed and can be used in any +# application. +# +# example: python zero_to_fp32.py . pytorch_model.bin + +import argparse +import torch +import glob +import math +import os +import re +from collections import OrderedDict + +# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with +# DeepSpeed data structures it has to be available in the current python environment. 
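+
+# NOTE (illustrative comment, not part of the original DeepSpeed script): a
+# worked example of the ZeRO-3 partition math in zero3_partitioned_param_info()
+# below. With world_size=4 and a 10-element parameter, each rank stores
+# partitioned_numel = ceil(10/4) = 3 elements; remainder = 10 % 4 = 2, so
+# padding_numel = 4 - 2 = 2 trailing pad elements on the last rank (1 real
+# element + 2 pads). Reconstruction narrows 3 elements out of each rank's flat
+# group, concatenates them into 12, then narrows back to the true 10 elements
+# before the final view(shape).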
+from deepspeed.utils import logger +from deepspeed.checkpoint.constants import (DS_VERSION, + OPTIMIZER_STATE_DICT, + SINGLE_PARTITION_OF_FP32_GROUPS, + FP32_FLAT_GROUPS, + ZERO_STAGE, + PARTITION_COUNT, + PARAM_SHAPES, + BUFFER_NAMES) + +debug = 0 + +# load to cpu +device = torch.device('cpu') + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + ''' + return [atoi(c) for c in re.split(r'(\d+)', text)] + + +def get_model_state_file(checkpoint_dir, zero_stage): + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + + # there should be only one file + if zero_stage == 2: + file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") + elif zero_stage == 3: + file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") + + if not os.path.exists(file): + raise FileNotFoundError(f"can't find model states file at '{file}'") + + return file + + +def get_optim_files(checkpoint_dir): + # XXX: need to test that this simple glob rule works for multi-node setup too + optim_files = sorted(glob.glob(os.path.join(checkpoint_dir, + "*_optim_states.pt")), + key=natural_keys) + + if len(optim_files) == 0: + raise FileNotFoundError( + f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'") + + return optim_files + + +def parse_model_state(file): + state_dict = torch.load(file, map_location=device) + + if BUFFER_NAMES not in state_dict: + raise ValueError(f"{file} is not a model state checkpoint") + buffer_names = state_dict[BUFFER_NAMES] + if debug: + print("Found buffers:", buffer_names) + + # recover just the buffers while restoring them to fp32 if they were saved in fp16 + buffers = { + k: v.float() + for k, + v in state_dict["module"].items() if k in buffer_names + } + param_shapes = state_dict[PARAM_SHAPES] + + ds_version = state_dict.get(DS_VERSION, None) + + return buffers, param_shapes, ds_version + + +def parse_optim_states(files, ds_checkpoint_dir): + + total_files = len(files) + state_dicts = [] + for f in files: + state_dicts.append(torch.load(f, map_location=device)) + + if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: + raise ValueError(f"{files[0]} is not a zero checkpoint") + zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] + world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] + + # For ZeRO-2 each param group can have different partition_count as data parallelism for expert + # parameters can be different from data parallelism for non-expert parameters. So we can just + # use the max of the partition_count to get the dp world_size. + + if type(world_size) is list: + world_size = max(world_size) + + if world_size != total_files: + raise ValueError( + f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. " + "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." 
+ ) + + # the groups are named differently in each stage + if zero_stage == 2: + fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS + elif zero_stage == 3: + fp32_groups_key = FP32_FLAT_GROUPS + else: + raise ValueError(f"unknown zero stage {zero_stage}") + + if zero_stage == 2: + fp32_flat_groups = [ + state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] + for i in range(len(state_dicts)) + ] + elif zero_stage == 3: + # if there is more than one param group, there will be multiple flattened tensors - one + # flattened tensor per group - for simplicity merge them into a single tensor + # + # XXX: could make the script more memory efficient for when there are multiple groups - it + # will require matching the sub-lists of param_shapes for each param group flattened tensor + + fp32_flat_groups = [ + torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], + 0) for i in range(len(state_dicts)) + ] + + return zero_stage, world_size, fp32_flat_groups + + +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): + """ + Returns fp32 state_dict reconstructed from ds checkpoint + + Args: + - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) + + """ + print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") + + optim_files = get_optim_files(ds_checkpoint_dir) + zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) + print( + f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") + + model_file = get_model_state_file(ds_checkpoint_dir, zero_stage) + buffers, param_shapes, ds_version = parse_model_state(model_file) + print(f'Parsing checkpoint created by deepspeed=={ds_version}') + + if zero_stage == 2: + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers) + elif zero_stage == 3: + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers) + + +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers): + + # Reconstruction protocol: + # + # XXX: document this + + if debug: + for i in range(world_size): + for j in range(len(fp32_flat_groups[0])): + print( + f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") + + # XXX: memory usage doubles here (zero2) + num_param_groups = len(fp32_flat_groups[0]) + merged_single_partition_of_fp32_groups = [] + for i in range(num_param_groups): + merged_partitions = [sd[i] for sd in fp32_flat_groups] + full_single_fp32_vector = torch.cat(merged_partitions, 0) + merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) + avail_numel = sum([ + full_single_fp32_vector.numel() + for full_single_fp32_vector in merged_single_partition_of_fp32_groups + ]) + + if debug: + wanted_params = sum([len(shapes) for shapes in param_shapes]) + wanted_numel = sum( + [sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + state_dict = OrderedDict() + + # buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + total_numel = 0 + total_params = 0 + for shapes, full_single_fp32_vector in 
zip(param_shapes, merged_single_partition_of_fp32_groups): + offset = 0 + avail_numel = full_single_fp32_vector.numel() + for name, shape in shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + if debug: + print( + f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} " + ) + state_dict[name] = full_single_fp32_vector.narrow( + 0, + offset, + unpartitioned_numel).view(shape) + offset += unpartitioned_numel + + # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and + # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex + # paddings performed in the code it's almost impossible to predict the exact numbers w/o the + # live optimizer object, so we are checking that the numbers are within the right range + align_to = 2 * world_size + + def zero2_align(x): + return align_to * math.ceil(x / align_to) + + if debug: + print(f"original offset={offset}, avail_numel={avail_numel}") + + offset = zero2_align(offset) + avail_numel = zero2_align(avail_numel) + + if debug: + print(f"aligned offset={offset}, avail_numel={avail_numel}") + + # Sanity check + if offset != avail_numel: + raise ValueError( + f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print( + f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements" + ) + + return state_dict + + +def zero3_partitioned_param_info(unpartitioned_numel, world_size): + remainder = unpartitioned_numel % world_size + padding_numel = (world_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + return partitioned_numel, padding_numel + + +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers): + + # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each + # param, re-consolidating each param, while dealing with padding if any + + avail_numel = fp32_flat_groups[0].numel() * world_size + # merge list of dicts, preserving order + param_shapes = {k: v for d in param_shapes for k, v in d.items()} + + if debug: + for i in range(world_size): + print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") + + wanted_params = len(param_shapes) + wanted_numel = sum(shape.numel() for shape in param_shapes.values()) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + state_dict = OrderedDict() + + # buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + offset = 0 + total_numel = 0 + total_params = 0 + for name, shape in param_shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + # XXX: memory usage doubles here + state_dict[name] = torch.cat( + tuple(fp32_flat_groups[i].narrow(0, + offset, + partitioned_numel) + for i in range(world_size)), + 0).narrow(0, + 0, + 
unpartitioned_numel).view(shape) + offset += partitioned_numel + + offset *= world_size + + # Sanity check + if offset != avail_numel: + raise ValueError( + f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print( + f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements" + ) + + return state_dict + + +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with + ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example + via a model hub. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` + + Returns: + - pytorch ``state_dict`` + + Note: this approach may not work if your application doesn't have sufficient free CPU memory and + you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with + the checkpoint. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + # do the training and checkpoint saving + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu + model = model.cpu() # move to cpu + model.load_state_dict(state_dict) + # submit to model hub or save the model to share with others + + In this example the ``model`` will no longer be usable in the deepspeed context of the same + application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead. + + """ + if tag is None: + latest_path = os.path.join(checkpoint_dir, 'latest') + if os.path.isfile(latest_path): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + raise ValueError(f"Unable to find 'latest' file at {latest_path}") + + ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) + + if not os.path.isdir(ds_checkpoint_dir): + raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") + + return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir) + + +def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be + loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + """ + + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + print(f"Saving fp32 state dict to {output_file}") + torch.save(state_dict, output_file) + + +def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): + """ + 1. Put the provided model to cpu + 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` + 3. 
Load it into the provided model + + Args: + - ``model``: the model object to update + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + + Returns: + - ``model``: modified model + + Make sure you have plenty of CPU memory available before you call this function. If you don't + have enough, use the ``zero_to_fp32.py`` utility to do the conversion. You will find it + conveniently placed for you in the checkpoint folder. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint + model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir) + # submit to model hub or save the model to share with others + + Note that once this was run, the ``model`` will no longer be usable in the deepspeed context + of the same application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + """ + logger.info(f"Extracting fp32 weights") + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + + logger.info(f"Overwriting model with fp32 weights") + model = model.cpu() + model.load_state_dict(state_dict, strict=False) + + return model + + if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument( + "checkpoint_dir", + type=str, + help="path to the desired checkpoint folder, e.g., path/checkpoint-12") + parser.add_argument( + "output_file", + type=str, + help= + "path to the pytorch fp32 state_dict output file (e.g.
path/checkpoint-12/pytorch_model.bin)" + ) + parser.add_argument("-d", "--debug", action='store_true', help="enable debug") + args = parser.parse_args() + + debug = args.debug + + convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file) diff --git a/checkpoint-3500/latest b/checkpoint-3500/latest new file mode 100644 index 0000000000000000000000000000000000000000..7b2c8602be034ae63f23b293f8d037fa7afa0c54 --- /dev/null +++ b/checkpoint-3500/latest @@ -0,0 +1 @@ +pytorch_model \ No newline at end of file diff --git a/checkpoint-3500/pytorch_model/mp_rank_00_model_states.pt b/checkpoint-3500/pytorch_model/mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..67e3e116852379a64f81b7df6b0dd944ca905f77 --- /dev/null +++ b/checkpoint-3500/pytorch_model/mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2bb929a1ae068c10e06d614d02ce7858086e9f1b26998047a58d83709394080e +size 1658603 diff --git a/checkpoint-3500/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt b/checkpoint-3500/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..99a29405f310f50d0b4d83072dc0637eb04af971 --- /dev/null +++ b/checkpoint-3500/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0e32d2e8755d1de48a650c24cdc151b3fbdf04e5d3528e6367f633350f7613cd +size 9586591 diff --git a/checkpoint-3500/random_states_0.pkl b/checkpoint-3500/random_states_0.pkl new file mode 100644 index 0000000000000000000000000000000000000000..204829de464e791e25c8e6aaed915126d84fcf8b --- /dev/null +++ b/checkpoint-3500/random_states_0.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:37e4039f32fbf1b3f1005bc281c39dad9f3be2c01ae0258e93996d1f75524ef5 +size 14631 diff --git a/checkpoint-3500/scheduler.bin b/checkpoint-3500/scheduler.bin new file mode 100644 index 0000000000000000000000000000000000000000..003d4a044df091bde8c2bd25cd4c56496084158b --- /dev/null +++ b/checkpoint-3500/scheduler.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b5b849c2e31cb13093da2ec90549fbd6c46a45d8cebba040880c4e0d3c5a65ac +size 559 diff --git a/checkpoint-3500/zero_to_fp32.py b/checkpoint-3500/zero_to_fp32.py new file mode 100644 index 0000000000000000000000000000000000000000..e5249853c89179a3d5ad212ead8c464d6c979df5 --- /dev/null +++ b/checkpoint-3500/zero_to_fp32.py @@ -0,0 +1,482 @@ +#!/usr/bin/env python + +# This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets +# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in +# the future. Once extracted, the weights don't require DeepSpeed and can be used in any +# application. +# +# example: python zero_to_fp32.py . pytorch_model.bin + +import argparse +import torch +import glob +import math +import os +import re +from collections import OrderedDict + +# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with +# DeepSpeed data structures it has to be available in the current python environment. 
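+
+# NOTE (illustrative comment, not part of the original DeepSpeed script): the
+# ZeRO-2 path below pads each flat fp32 group to a multiple of
+# align_to = 2 * world_size for NCCL efficiency, so the sanity check first
+# rounds both counters up with align_to * math.ceil(x / align_to). For example,
+# with world_size=2 (align_to=4), offset=1001 and avail_numel=1004 both align
+# to 1004 and the check passes, even though the raw numbers differ by up to
+# 2 * world_size before alignment.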
+from deepspeed.utils import logger +from deepspeed.checkpoint.constants import (DS_VERSION, + OPTIMIZER_STATE_DICT, + SINGLE_PARTITION_OF_FP32_GROUPS, + FP32_FLAT_GROUPS, + ZERO_STAGE, + PARTITION_COUNT, + PARAM_SHAPES, + BUFFER_NAMES) + +debug = 0 + +# load to cpu +device = torch.device('cpu') + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + ''' + return [atoi(c) for c in re.split(r'(\d+)', text)] + + +def get_model_state_file(checkpoint_dir, zero_stage): + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + + # there should be only one file + if zero_stage == 2: + file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") + elif zero_stage == 3: + file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") + + if not os.path.exists(file): + raise FileNotFoundError(f"can't find model states file at '{file}'") + + return file + + +def get_optim_files(checkpoint_dir): + # XXX: need to test that this simple glob rule works for multi-node setup too + optim_files = sorted(glob.glob(os.path.join(checkpoint_dir, + "*_optim_states.pt")), + key=natural_keys) + + if len(optim_files) == 0: + raise FileNotFoundError( + f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'") + + return optim_files + + +def parse_model_state(file): + state_dict = torch.load(file, map_location=device) + + if BUFFER_NAMES not in state_dict: + raise ValueError(f"{file} is not a model state checkpoint") + buffer_names = state_dict[BUFFER_NAMES] + if debug: + print("Found buffers:", buffer_names) + + # recover just the buffers while restoring them to fp32 if they were saved in fp16 + buffers = { + k: v.float() + for k, + v in state_dict["module"].items() if k in buffer_names + } + param_shapes = state_dict[PARAM_SHAPES] + + ds_version = state_dict.get(DS_VERSION, None) + + return buffers, param_shapes, ds_version + + +def parse_optim_states(files, ds_checkpoint_dir): + + total_files = len(files) + state_dicts = [] + for f in files: + state_dicts.append(torch.load(f, map_location=device)) + + if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: + raise ValueError(f"{files[0]} is not a zero checkpoint") + zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] + world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] + + # For ZeRO-2 each param group can have different partition_count as data parallelism for expert + # parameters can be different from data parallelism for non-expert parameters. So we can just + # use the max of the partition_count to get the dp world_size. + + if type(world_size) is list: + world_size = max(world_size) + + if world_size != total_files: + raise ValueError( + f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. " + "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." 
+ ) + + # the groups are named differently in each stage + if zero_stage == 2: + fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS + elif zero_stage == 3: + fp32_groups_key = FP32_FLAT_GROUPS + else: + raise ValueError(f"unknown zero stage {zero_stage}") + + if zero_stage == 2: + fp32_flat_groups = [ + state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] + for i in range(len(state_dicts)) + ] + elif zero_stage == 3: + # if there is more than one param group, there will be multiple flattened tensors - one + # flattened tensor per group - for simplicity merge them into a single tensor + # + # XXX: could make the script more memory efficient for when there are multiple groups - it + # will require matching the sub-lists of param_shapes for each param group flattened tensor + + fp32_flat_groups = [ + torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], + 0) for i in range(len(state_dicts)) + ] + + return zero_stage, world_size, fp32_flat_groups + + +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): + """ + Returns fp32 state_dict reconstructed from ds checkpoint + + Args: + - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) + + """ + print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") + + optim_files = get_optim_files(ds_checkpoint_dir) + zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) + print( + f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") + + model_file = get_model_state_file(ds_checkpoint_dir, zero_stage) + buffers, param_shapes, ds_version = parse_model_state(model_file) + print(f'Parsing checkpoint created by deepspeed=={ds_version}') + + if zero_stage == 2: + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers) + elif zero_stage == 3: + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers) + + +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers): + + # Reconstruction protocol: + # + # XXX: document this + + if debug: + for i in range(world_size): + for j in range(len(fp32_flat_groups[0])): + print( + f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") + + # XXX: memory usage doubles here (zero2) + num_param_groups = len(fp32_flat_groups[0]) + merged_single_partition_of_fp32_groups = [] + for i in range(num_param_groups): + merged_partitions = [sd[i] for sd in fp32_flat_groups] + full_single_fp32_vector = torch.cat(merged_partitions, 0) + merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) + avail_numel = sum([ + full_single_fp32_vector.numel() + for full_single_fp32_vector in merged_single_partition_of_fp32_groups + ]) + + if debug: + wanted_params = sum([len(shapes) for shapes in param_shapes]) + wanted_numel = sum( + [sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + state_dict = OrderedDict() + + # buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + total_numel = 0 + total_params = 0 + for shapes, full_single_fp32_vector in 
zip(param_shapes, merged_single_partition_of_fp32_groups): + offset = 0 + avail_numel = full_single_fp32_vector.numel() + for name, shape in shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + if debug: + print( + f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} " + ) + state_dict[name] = full_single_fp32_vector.narrow( + 0, + offset, + unpartitioned_numel).view(shape) + offset += unpartitioned_numel + + # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and + # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex + # paddings performed in the code it's almost impossible to predict the exact numbers w/o the + # live optimizer object, so we are checking that the numbers are within the right range + align_to = 2 * world_size + + def zero2_align(x): + return align_to * math.ceil(x / align_to) + + if debug: + print(f"original offset={offset}, avail_numel={avail_numel}") + + offset = zero2_align(offset) + avail_numel = zero2_align(avail_numel) + + if debug: + print(f"aligned offset={offset}, avail_numel={avail_numel}") + + # Sanity check + if offset != avail_numel: + raise ValueError( + f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print( + f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements" + ) + + return state_dict + + +def zero3_partitioned_param_info(unpartitioned_numel, world_size): + remainder = unpartitioned_numel % world_size + padding_numel = (world_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + return partitioned_numel, padding_numel + + +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers): + + # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each + # param, re-consolidating each param, while dealing with padding if any + + avail_numel = fp32_flat_groups[0].numel() * world_size + # merge list of dicts, preserving order + param_shapes = {k: v for d in param_shapes for k, v in d.items()} + + if debug: + for i in range(world_size): + print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") + + wanted_params = len(param_shapes) + wanted_numel = sum(shape.numel() for shape in param_shapes.values()) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + state_dict = OrderedDict() + + # buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + offset = 0 + total_numel = 0 + total_params = 0 + for name, shape in param_shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + # XXX: memory usage doubles here + state_dict[name] = torch.cat( + tuple(fp32_flat_groups[i].narrow(0, + offset, + partitioned_numel) + for i in range(world_size)), + 0).narrow(0, + 0, + 
unpartitioned_numel).view(shape) + offset += partitioned_numel + + offset *= world_size + + # Sanity check + if offset != avail_numel: + raise ValueError( + f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print( + f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements" + ) + + return state_dict + + +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with + ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example + via a model hub. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` + + Returns: + - pytorch ``state_dict`` + + Note: this approach may not work if your application doesn't have sufficient free CPU memory and + you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with + the checkpoint. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + # do the training and checkpoint saving + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu + model = model.cpu() # move to cpu + model.load_state_dict(state_dict) + # submit to model hub or save the model to share with others + + In this example the ``model`` will no longer be usable in the deepspeed context of the same + application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead. + + """ + if tag is None: + latest_path = os.path.join(checkpoint_dir, 'latest') + if os.path.isfile(latest_path): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + raise ValueError(f"Unable to find 'latest' file at {latest_path}") + + ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) + + if not os.path.isdir(ds_checkpoint_dir): + raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") + + return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir) + + +def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be + loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + """ + + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + print(f"Saving fp32 state dict to {output_file}") + torch.save(state_dict, output_file) + + +def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): + """ + 1. Put the provided model to cpu + 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` + 3. 
Load it into the provided model + + Args: + - ``model``: the model object to update + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + + Returns: + - ``model``: modified model + + Make sure you have plenty of CPU memory available before you call this function. If you don't + have enough, use the ``zero_to_fp32.py`` utility to do the conversion. You will find it + conveniently placed for you in the checkpoint folder. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint + model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir) + # submit to model hub or save the model to share with others + + Note that once this was run, the ``model`` will no longer be usable in the deepspeed context + of the same application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + """ + logger.info(f"Extracting fp32 weights") + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + + logger.info(f"Overwriting model with fp32 weights") + model = model.cpu() + model.load_state_dict(state_dict, strict=False) + + return model + + if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument( + "checkpoint_dir", + type=str, + help="path to the desired checkpoint folder, e.g., path/checkpoint-12") + parser.add_argument( + "output_file", + type=str, + help= + "path to the pytorch fp32 state_dict output file (e.g.
path/checkpoint-12/pytorch_model.bin)" + ) + parser.add_argument("-d", "--debug", action='store_true', help="enable debug") + args = parser.parse_args() + + debug = args.debug + + convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file) diff --git a/checkpoint-4000/latest b/checkpoint-4000/latest new file mode 100644 index 0000000000000000000000000000000000000000..7b2c8602be034ae63f23b293f8d037fa7afa0c54 --- /dev/null +++ b/checkpoint-4000/latest @@ -0,0 +1 @@ +pytorch_model \ No newline at end of file diff --git a/checkpoint-4000/pytorch_model/mp_rank_00_model_states.pt b/checkpoint-4000/pytorch_model/mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..8c4cf1b12766ec7798f8282d59e3151b4bd4a2bf --- /dev/null +++ b/checkpoint-4000/pytorch_model/mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9442e802bd6c1d36a819f1441f87eec1c57ed193c5c0577f68e1dc61ac973dd0 +size 1658603 diff --git a/checkpoint-4000/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt b/checkpoint-4000/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..85a82388f3bacf638849cc249ec6a791c0b8eb70 --- /dev/null +++ b/checkpoint-4000/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8608f8552d2db5146e71513cb00d22e96b3aeff91127fa39db1d522ecdbca6b0 +size 9586591 diff --git a/checkpoint-4000/random_states_0.pkl b/checkpoint-4000/random_states_0.pkl new file mode 100644 index 0000000000000000000000000000000000000000..6359dd2777ace056ad2f5d17d37ecb2c76d64127 --- /dev/null +++ b/checkpoint-4000/random_states_0.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2e1264245326d49031775e35b25867d85d00bf11d177048ab50f3755bde2f545 +size 14631 diff --git a/checkpoint-4000/scheduler.bin b/checkpoint-4000/scheduler.bin new file mode 100644 index 0000000000000000000000000000000000000000..9893a3672c2d06ec50067a5bb481f597af5b7fde --- /dev/null +++ b/checkpoint-4000/scheduler.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fd1644d8923601113d06c7895a6127ab81875778e451cd1a62d9e929c8971dde +size 559 diff --git a/checkpoint-4000/zero_to_fp32.py b/checkpoint-4000/zero_to_fp32.py new file mode 100644 index 0000000000000000000000000000000000000000..e5249853c89179a3d5ad212ead8c464d6c979df5 --- /dev/null +++ b/checkpoint-4000/zero_to_fp32.py @@ -0,0 +1,482 @@ +#!/usr/bin/env python + +# This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets +# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in +# the future. Once extracted, the weights don't require DeepSpeed and can be used in any +# application. +# +# example: python zero_to_fp32.py . pytorch_model.bin + +import argparse +import torch +import glob +import math +import os +import re +from collections import OrderedDict + +# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with +# DeepSpeed data structures it has to be available in the current python environment. 
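+
+# NOTE (illustrative comment, not part of the original DeepSpeed script; the
+# output path is an assumption): once the consolidated file exists, it can be
+# inspected without DeepSpeed installed, e.g.
+#
+#   import torch
+#   sd = torch.load("pytorch_model.bin", map_location="cpu")
+#   print(len(sd), sum(t.numel() for t in sd.values()))
+#
+# The numel total should match the "Reconstructed fp32 state dict ..." figure
+# printed by this script, plus any buffers, which the printed total excludes.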
+from deepspeed.utils import logger +from deepspeed.checkpoint.constants import (DS_VERSION, + OPTIMIZER_STATE_DICT, + SINGLE_PARTITION_OF_FP32_GROUPS, + FP32_FLAT_GROUPS, + ZERO_STAGE, + PARTITION_COUNT, + PARAM_SHAPES, + BUFFER_NAMES) + +debug = 0 + +# load to cpu +device = torch.device('cpu') + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + ''' + return [atoi(c) for c in re.split(r'(\d+)', text)] + + +def get_model_state_file(checkpoint_dir, zero_stage): + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + + # there should be only one file + if zero_stage == 2: + file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") + elif zero_stage == 3: + file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") + + if not os.path.exists(file): + raise FileNotFoundError(f"can't find model states file at '{file}'") + + return file + + +def get_optim_files(checkpoint_dir): + # XXX: need to test that this simple glob rule works for multi-node setup too + optim_files = sorted(glob.glob(os.path.join(checkpoint_dir, + "*_optim_states.pt")), + key=natural_keys) + + if len(optim_files) == 0: + raise FileNotFoundError( + f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'") + + return optim_files + + +def parse_model_state(file): + state_dict = torch.load(file, map_location=device) + + if BUFFER_NAMES not in state_dict: + raise ValueError(f"{file} is not a model state checkpoint") + buffer_names = state_dict[BUFFER_NAMES] + if debug: + print("Found buffers:", buffer_names) + + # recover just the buffers while restoring them to fp32 if they were saved in fp16 + buffers = { + k: v.float() + for k, + v in state_dict["module"].items() if k in buffer_names + } + param_shapes = state_dict[PARAM_SHAPES] + + ds_version = state_dict.get(DS_VERSION, None) + + return buffers, param_shapes, ds_version + + +def parse_optim_states(files, ds_checkpoint_dir): + + total_files = len(files) + state_dicts = [] + for f in files: + state_dicts.append(torch.load(f, map_location=device)) + + if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: + raise ValueError(f"{files[0]} is not a zero checkpoint") + zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] + world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] + + # For ZeRO-2 each param group can have different partition_count as data parallelism for expert + # parameters can be different from data parallelism for non-expert parameters. So we can just + # use the max of the partition_count to get the dp world_size. + + if type(world_size) is list: + world_size = max(world_size) + + if world_size != total_files: + raise ValueError( + f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. " + "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." 
+ ) + + # the groups are named differently in each stage + if zero_stage == 2: + fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS + elif zero_stage == 3: + fp32_groups_key = FP32_FLAT_GROUPS + else: + raise ValueError(f"unknown zero stage {zero_stage}") + + if zero_stage == 2: + fp32_flat_groups = [ + state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] + for i in range(len(state_dicts)) + ] + elif zero_stage == 3: + # if there is more than one param group, there will be multiple flattened tensors - one + # flattened tensor per group - for simplicity merge them into a single tensor + # + # XXX: could make the script more memory efficient for when there are multiple groups - it + # will require matching the sub-lists of param_shapes for each param group flattened tensor + + fp32_flat_groups = [ + torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], + 0) for i in range(len(state_dicts)) + ] + + return zero_stage, world_size, fp32_flat_groups + + +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): + """ + Returns fp32 state_dict reconstructed from ds checkpoint + + Args: + - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) + + """ + print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") + + optim_files = get_optim_files(ds_checkpoint_dir) + zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) + print( + f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") + + model_file = get_model_state_file(ds_checkpoint_dir, zero_stage) + buffers, param_shapes, ds_version = parse_model_state(model_file) + print(f'Parsing checkpoint created by deepspeed=={ds_version}') + + if zero_stage == 2: + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers) + elif zero_stage == 3: + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers) + + +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers): + + # Reconstruction protocol: + # + # XXX: document this + + if debug: + for i in range(world_size): + for j in range(len(fp32_flat_groups[0])): + print( + f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") + + # XXX: memory usage doubles here (zero2) + num_param_groups = len(fp32_flat_groups[0]) + merged_single_partition_of_fp32_groups = [] + for i in range(num_param_groups): + merged_partitions = [sd[i] for sd in fp32_flat_groups] + full_single_fp32_vector = torch.cat(merged_partitions, 0) + merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) + avail_numel = sum([ + full_single_fp32_vector.numel() + for full_single_fp32_vector in merged_single_partition_of_fp32_groups + ]) + + if debug: + wanted_params = sum([len(shapes) for shapes in param_shapes]) + wanted_numel = sum( + [sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + state_dict = OrderedDict() + + # buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + total_numel = 0 + total_params = 0 + for shapes, full_single_fp32_vector in 
zip(param_shapes, merged_single_partition_of_fp32_groups): + offset = 0 + avail_numel = full_single_fp32_vector.numel() + for name, shape in shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + if debug: + print( + f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} " + ) + state_dict[name] = full_single_fp32_vector.narrow( + 0, + offset, + unpartitioned_numel).view(shape) + offset += unpartitioned_numel + + # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and + # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex + # paddings performed in the code it's almost impossible to predict the exact numbers w/o the + # live optimizer object, so we are checking that the numbers are within the right range + align_to = 2 * world_size + + def zero2_align(x): + return align_to * math.ceil(x / align_to) + + if debug: + print(f"original offset={offset}, avail_numel={avail_numel}") + + offset = zero2_align(offset) + avail_numel = zero2_align(avail_numel) + + if debug: + print(f"aligned offset={offset}, avail_numel={avail_numel}") + + # Sanity check + if offset != avail_numel: + raise ValueError( + f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print( + f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements" + ) + + return state_dict + + +def zero3_partitioned_param_info(unpartitioned_numel, world_size): + remainder = unpartitioned_numel % world_size + padding_numel = (world_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + return partitioned_numel, padding_numel + + +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers): + + # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each + # param, re-consolidating each param, while dealing with padding if any + + avail_numel = fp32_flat_groups[0].numel() * world_size + # merge list of dicts, preserving order + param_shapes = {k: v for d in param_shapes for k, v in d.items()} + + if debug: + for i in range(world_size): + print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") + + wanted_params = len(param_shapes) + wanted_numel = sum(shape.numel() for shape in param_shapes.values()) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + state_dict = OrderedDict() + + # buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + offset = 0 + total_numel = 0 + total_params = 0 + for name, shape in param_shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + # XXX: memory usage doubles here + state_dict[name] = torch.cat( + tuple(fp32_flat_groups[i].narrow(0, + offset, + partitioned_numel) + for i in range(world_size)), + 0).narrow(0, + 0, + 
unpartitioned_numel).view(shape) + offset += partitioned_numel + + offset *= world_size + + # Sanity check + if offset != avail_numel: + raise ValueError( + f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print( + f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements" + ) + + return state_dict + + +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with + ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example + via a model hub. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` + + Returns: + - pytorch ``state_dict`` + + Note: this approach may not work if your application doesn't have sufficient free CPU memory and + you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with + the checkpoint. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + # do the training and checkpoint saving + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu + model = model.cpu() # move to cpu + model.load_state_dict(state_dict) + # submit to model hub or save the model to share with others + + In this example the ``model`` will no longer be usable in the deepspeed context of the same + application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead. + + """ + if tag is None: + latest_path = os.path.join(checkpoint_dir, 'latest') + if os.path.isfile(latest_path): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + raise ValueError(f"Unable to find 'latest' file at {latest_path}") + + ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) + + if not os.path.isdir(ds_checkpoint_dir): + raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") + + return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir) + + +def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be + loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + """ + + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + print(f"Saving fp32 state dict to {output_file}") + torch.save(state_dict, output_file) + + +def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): + """ + 1. Put the provided model to cpu + 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` + 3. 
Load it into the provided model + + Args: + - ``model``: the model object to update + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + + Returns: + - ``model`: modified model + + Make sure you have plenty of CPU memory available before you call this function. If you don't + have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it + conveniently placed for you in the checkpoint folder. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint + model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir) + # submit to model hub or save the model to share with others + + Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context + of the same application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + """ + logger.info(f"Extracting fp32 weights") + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + + logger.info(f"Overwriting model with fp32 weights") + model = model.cpu() + model.load_state_dict(state_dict, strict=False) + + return model + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument( + "checkpoint_dir", + type=str, + help="path to the desired checkpoint folder, e.g., path/checkpoint-12") + parser.add_argument( + "output_file", + type=str, + help= + "path to the pytorch fp32 state_dict output file (e.g. 
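Every checkpoint folder in this diff ships the same ``zero_to_fp32.py`` helper shown above (identical blob ``e5249853``), and each ``latest`` file holds the tag ``pytorch_model``, so the converter resolves the tag folder automatically. A minimal offline-conversion sketch for one of these checkpoints (the output filename, and the assumption that the script's directory is on the import path, are illustrative only) might be ::

    from zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

    # reads checkpoint-4000/latest -> "pytorch_model", then consolidates the
    # ZeRO shards into a plain fp32 state_dict loadable with torch.load()
    convert_zero_checkpoint_to_fp32_state_dict(
        "checkpoint-4000", "checkpoint-4000/pytorch_model_fp32.bin")

The shell equivalent is ``python checkpoint-4000/zero_to_fp32.py checkpoint-4000 checkpoint-4000/pytorch_model_fp32.bin``.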
diff --git a/checkpoint-4500/latest b/checkpoint-4500/latest new file mode 100644 index 0000000000000000000000000000000000000000..7b2c8602be034ae63f23b293f8d037fa7afa0c54 --- /dev/null +++ b/checkpoint-4500/latest @@ -0,0 +1 @@ +pytorch_model \ No newline at end of file diff --git a/checkpoint-4500/pytorch_model/mp_rank_00_model_states.pt b/checkpoint-4500/pytorch_model/mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..c812afd61323a376114e649fbbc7e739ea993867 --- /dev/null +++ b/checkpoint-4500/pytorch_model/mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a0b4336583375d01d5c490fba8f1e8163032c2a6ab84b0cc6a524b25210a4b23 +size 1658603 diff --git a/checkpoint-4500/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt b/checkpoint-4500/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..5015c734505a33df7bd958405685356441bc3d9a --- /dev/null +++ b/checkpoint-4500/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cbe1d0369c0bf0d9a0dc2c36d30bd965b2c7777ed8dcbd1a6f8db40fd4399e54 +size 9586591 diff --git a/checkpoint-4500/random_states_0.pkl b/checkpoint-4500/random_states_0.pkl new file mode 100644 index 0000000000000000000000000000000000000000..13a1f1429025aefd9b75217b495a2318565a59ff --- /dev/null +++ b/checkpoint-4500/random_states_0.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:73aa415afd8d6e66adb2568119d42cde424c6062e4e77f52af8ea77f09f2422e +size 14631 diff --git a/checkpoint-4500/scheduler.bin b/checkpoint-4500/scheduler.bin new file mode 100644 index 0000000000000000000000000000000000000000..85e177f77181395e39b17a4aab1b64d1f06cbf63 --- /dev/null +++ b/checkpoint-4500/scheduler.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:764da6777b46d92cbd141aec48e4e7cf13950250c810fa0323cc941298762fd7 +size 559
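The padding arithmetic in the script's ``zero3_partitioned_param_info`` makes each parameter divide evenly across ranks. A small worked example, with values chosen purely for illustration ::

    # a parameter with 10 elements sharded across world_size = 4 ranks:
    #   remainder         = 10 % 4       = 2
    #   padding_numel     = 4 - 2        = 2
    #   partitioned_numel = ceil(10 / 4) = 3
    # so 4 ranks * 3 elements = 12 = 10 real elements + 2 padding elements
    partitioned_numel, padding_numel = zero3_partitioned_param_info(10, 4)
    assert (partitioned_numel, padding_numel) == (3, 2)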
diff --git a/checkpoint-500/latest b/checkpoint-500/latest new file mode 100644 index 0000000000000000000000000000000000000000..7b2c8602be034ae63f23b293f8d037fa7afa0c54 --- /dev/null +++ b/checkpoint-500/latest @@ -0,0 +1 @@ +pytorch_model \ No newline at end of file diff --git a/checkpoint-500/pytorch_model/mp_rank_00_model_states.pt b/checkpoint-500/pytorch_model/mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..6c7ae647bbeeb7d7096d360cf482cc920f9bd776 --- /dev/null +++ b/checkpoint-500/pytorch_model/mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1b41604055c461a5ed00017ee79bebc161833d27f0211e14435269af683fb2e9 +size 1658603 diff --git a/checkpoint-500/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt b/checkpoint-500/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..4b6125aeeed6c68a9da836567dcaa8db5dd8c22b --- /dev/null +++ b/checkpoint-500/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1597f2e79537d704348fbc4722861cca86b8d62ce6ae76ca4a0dc28245fe6b12 +size 9586591 diff --git a/checkpoint-500/random_states_0.pkl b/checkpoint-500/random_states_0.pkl new file mode 100644 index 0000000000000000000000000000000000000000..a0db94f5ac28d3d39bd4c155685aa77923851e2f --- /dev/null +++ b/checkpoint-500/random_states_0.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:233ff3728c6f7c487d5e0bb5ae12832266f52f4b72e11efa27112ff361fc48de +size 14631 diff --git a/checkpoint-500/scheduler.bin b/checkpoint-500/scheduler.bin new file mode 100644 index 0000000000000000000000000000000000000000..895b50762ee5c65de12093cc393594b1d89e17e1 --- /dev/null +++ b/checkpoint-500/scheduler.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5227fbc11d58a7ac396f456407e541d65674f34bc79fa2902a74e7cb01220e8d +size 559
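When the consolidated weights are wanted in memory rather than on disk, ``get_fp32_state_dict_from_zero_checkpoint`` can be called with the tag spelled out; for the checkpoints in this repo the explicit form is equivalent to the default, since each ``latest`` file contains ``pytorch_model``. A sketch, again assuming the script is importable from the current directory ::

    from zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

    # tag="pytorch_model" matches the tag folder of every checkpoint here
    state_dict = get_fp32_state_dict_from_zero_checkpoint(
        "checkpoint-500", tag="pytorch_model")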
diff --git a/checkpoint-5000/latest b/checkpoint-5000/latest new file mode 100644 index 0000000000000000000000000000000000000000..7b2c8602be034ae63f23b293f8d037fa7afa0c54 --- /dev/null +++ b/checkpoint-5000/latest @@ -0,0 +1 @@ +pytorch_model \ No newline at end of file diff --git a/checkpoint-5000/pytorch_model/mp_rank_00_model_states.pt b/checkpoint-5000/pytorch_model/mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..a2d3e9caa0659ef6076ddb4bb64a4e08138b4da7 --- /dev/null +++ b/checkpoint-5000/pytorch_model/mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c72fe4b20f647bd44a3fc5994ff11520912e6a2600207bd90896eb6b53b2052e +size 1658603 diff --git a/checkpoint-5000/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt b/checkpoint-5000/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..182e14b2fedbd56e098e0677eb52145a2537891d --- /dev/null +++ b/checkpoint-5000/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:093a3ade2bc7d6905b19ae689bcdcbc6de82cae93d2f2bcd4c17d8b72830115d +size 9586591 diff --git a/checkpoint-5000/random_states_0.pkl b/checkpoint-5000/random_states_0.pkl new file mode 100644 index 0000000000000000000000000000000000000000..ce5ce361b13c40407b4f6c416c401678afcb6913 --- /dev/null +++ b/checkpoint-5000/random_states_0.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7708516cc7d289a250a372ef38e4ce475afa6e0636eec3b5acaae09a064a6c68 +size 14631 diff --git a/checkpoint-5000/scheduler.bin b/checkpoint-5000/scheduler.bin new file mode 100644 index 0000000000000000000000000000000000000000..fad8a9f9baa6cd97932e5b25c9dcf25cd48416dc --- /dev/null +++ b/checkpoint-5000/scheduler.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cafe10e76f2de5841220e83586e32d1bc549d9b2db65bfc7ffa1d3b78e1c5847 +size 559
unpartitioned_numel).view(shape) + offset += partitioned_numel + + offset *= world_size + + # Sanity check + if offset != avail_numel: + raise ValueError( + f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print( + f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements" + ) + + return state_dict + + +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with + ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example + via a model hub. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` + + Returns: + - pytorch ``state_dict`` + + Note: this approach may not work if your application doesn't have sufficient free CPU memory and + you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with + the checkpoint. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + # do the training and checkpoint saving + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu + model = model.cpu() # move to cpu + model.load_state_dict(state_dict) + # submit to model hub or save the model to share with others + + In this example the ``model`` will no longer be usable in the deepspeed context of the same + application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead. + + """ + if tag is None: + latest_path = os.path.join(checkpoint_dir, 'latest') + if os.path.isfile(latest_path): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + raise ValueError(f"Unable to find 'latest' file at {latest_path}") + + ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) + + if not os.path.isdir(ds_checkpoint_dir): + raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") + + return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir) + + +def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be + loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + """ + + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + print(f"Saving fp32 state dict to {output_file}") + torch.save(state_dict, output_file) + + +def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): + """ + 1. Put the provided model to cpu + 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` + 3. 
Load it into the provided model + + Args: + - ``model``: the model object to update + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + + Returns: + - ``model`: modified model + + Make sure you have plenty of CPU memory available before you call this function. If you don't + have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it + conveniently placed for you in the checkpoint folder. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint + model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir) + # submit to model hub or save the model to share with others + + Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context + of the same application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + """ + logger.info(f"Extracting fp32 weights") + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + + logger.info(f"Overwriting model with fp32 weights") + model = model.cpu() + model.load_state_dict(state_dict, strict=False) + + return model + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument( + "checkpoint_dir", + type=str, + help="path to the desired checkpoint folder, e.g., path/checkpoint-12") + parser.add_argument( + "output_file", + type=str, + help= + "path to the pytorch fp32 state_dict output file (e.g. 
path/checkpoint-12/pytorch_model.bin)" + ) + parser.add_argument("-d", "--debug", action='store_true', help="enable debug") + args = parser.parse_args() + + debug = args.debug + + convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file) diff --git a/checkpoint-5500/latest b/checkpoint-5500/latest new file mode 100644 index 0000000000000000000000000000000000000000..7b2c8602be034ae63f23b293f8d037fa7afa0c54 --- /dev/null +++ b/checkpoint-5500/latest @@ -0,0 +1 @@ +pytorch_model \ No newline at end of file diff --git a/checkpoint-5500/pytorch_model/mp_rank_00_model_states.pt b/checkpoint-5500/pytorch_model/mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..58c9d3c9ef305e243b1465c6d95829e57fbbf555 --- /dev/null +++ b/checkpoint-5500/pytorch_model/mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7503d87eb839bf1010393667d78682058af8b6e2261ed819651dfcd90223cce7 +size 1658603 diff --git a/checkpoint-5500/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt b/checkpoint-5500/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..e9f4542ae65b61927ce61172eb556a1f89b63fc8 --- /dev/null +++ b/checkpoint-5500/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1fadbcd8f2e8d6a3bd264f0931ddb930784d1f617ab1b3d839aaf365124c34b7 +size 9586591 diff --git a/checkpoint-5500/random_states_0.pkl b/checkpoint-5500/random_states_0.pkl new file mode 100644 index 0000000000000000000000000000000000000000..24c2cec624f33fcc09f25821e54908c7cb08a4ed --- /dev/null +++ b/checkpoint-5500/random_states_0.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6c6f5f85e7588e0a98e8f402775a8bafb4b091f25536c005e1d101c9756ef9ad +size 14631 diff --git a/checkpoint-5500/scheduler.bin b/checkpoint-5500/scheduler.bin new file mode 100644 index 0000000000000000000000000000000000000000..d1100706d8b8f4dafbcfc3c33ee8666efc4140dc --- /dev/null +++ b/checkpoint-5500/scheduler.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2f5136f7d5e113aea32fc5a7184f6e3e519b8917572a3f5cb914fbf149f47945 +size 559 diff --git a/checkpoint-5500/zero_to_fp32.py b/checkpoint-5500/zero_to_fp32.py new file mode 100644 index 0000000000000000000000000000000000000000..e5249853c89179a3d5ad212ead8c464d6c979df5 --- /dev/null +++ b/checkpoint-5500/zero_to_fp32.py @@ -0,0 +1,482 @@ +#!/usr/bin/env python + +# This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets +# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in +# the future. Once extracted, the weights don't require DeepSpeed and can be used in any +# application. +# +# example: python zero_to_fp32.py . pytorch_model.bin + +import argparse +import torch +import glob +import math +import os +import re +from collections import OrderedDict + +# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with +# DeepSpeed data structures it has to be available in the current python environment. 
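+# An in-process sketch (an illustration, assuming enough free CPU RAM to hold
+# the full fp32 state_dict, and a hypothetical 'model' whose parameter names
+# match); it mirrors the docstring example further below in this file:
+#
+#   from zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
+#   state_dict = get_fp32_state_dict_from_zero_checkpoint("checkpoint-5500")
+#   model = model.cpu()
+#   model.load_state_dict(state_dict)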
+from deepspeed.utils import logger +from deepspeed.checkpoint.constants import (DS_VERSION, + OPTIMIZER_STATE_DICT, + SINGLE_PARTITION_OF_FP32_GROUPS, + FP32_FLAT_GROUPS, + ZERO_STAGE, + PARTITION_COUNT, + PARAM_SHAPES, + BUFFER_NAMES) + +debug = 0 + +# load to cpu +device = torch.device('cpu') + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + ''' + return [atoi(c) for c in re.split(r'(\d+)', text)] + + +def get_model_state_file(checkpoint_dir, zero_stage): + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + + # there should be only one file + if zero_stage == 2: + file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") + elif zero_stage == 3: + file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") + + if not os.path.exists(file): + raise FileNotFoundError(f"can't find model states file at '{file}'") + + return file + + +def get_optim_files(checkpoint_dir): + # XXX: need to test that this simple glob rule works for multi-node setup too + optim_files = sorted(glob.glob(os.path.join(checkpoint_dir, + "*_optim_states.pt")), + key=natural_keys) + + if len(optim_files) == 0: + raise FileNotFoundError( + f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'") + + return optim_files + + +def parse_model_state(file): + state_dict = torch.load(file, map_location=device) + + if BUFFER_NAMES not in state_dict: + raise ValueError(f"{file} is not a model state checkpoint") + buffer_names = state_dict[BUFFER_NAMES] + if debug: + print("Found buffers:", buffer_names) + + # recover just the buffers while restoring them to fp32 if they were saved in fp16 + buffers = { + k: v.float() + for k, + v in state_dict["module"].items() if k in buffer_names + } + param_shapes = state_dict[PARAM_SHAPES] + + ds_version = state_dict.get(DS_VERSION, None) + + return buffers, param_shapes, ds_version + + +def parse_optim_states(files, ds_checkpoint_dir): + + total_files = len(files) + state_dicts = [] + for f in files: + state_dicts.append(torch.load(f, map_location=device)) + + if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: + raise ValueError(f"{files[0]} is not a zero checkpoint") + zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] + world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] + + # For ZeRO-2 each param group can have different partition_count as data parallelism for expert + # parameters can be different from data parallelism for non-expert parameters. So we can just + # use the max of the partition_count to get the dp world_size. + + if type(world_size) is list: + world_size = max(world_size) + + if world_size != total_files: + raise ValueError( + f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. " + "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." 
+ ) + + # the groups are named differently in each stage + if zero_stage == 2: + fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS + elif zero_stage == 3: + fp32_groups_key = FP32_FLAT_GROUPS + else: + raise ValueError(f"unknown zero stage {zero_stage}") + + if zero_stage == 2: + fp32_flat_groups = [ + state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] + for i in range(len(state_dicts)) + ] + elif zero_stage == 3: + # if there is more than one param group, there will be multiple flattened tensors - one + # flattened tensor per group - for simplicity merge them into a single tensor + # + # XXX: could make the script more memory efficient for when there are multiple groups - it + # will require matching the sub-lists of param_shapes for each param group flattened tensor + + fp32_flat_groups = [ + torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], + 0) for i in range(len(state_dicts)) + ] + + return zero_stage, world_size, fp32_flat_groups + + +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): + """ + Returns fp32 state_dict reconstructed from ds checkpoint + + Args: + - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) + + """ + print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") + + optim_files = get_optim_files(ds_checkpoint_dir) + zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) + print( + f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") + + model_file = get_model_state_file(ds_checkpoint_dir, zero_stage) + buffers, param_shapes, ds_version = parse_model_state(model_file) + print(f'Parsing checkpoint created by deepspeed=={ds_version}') + + if zero_stage == 2: + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers) + elif zero_stage == 3: + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers) + + +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers): + + # Reconstruction protocol: + # + # XXX: document this + + if debug: + for i in range(world_size): + for j in range(len(fp32_flat_groups[0])): + print( + f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") + + # XXX: memory usage doubles here (zero2) + num_param_groups = len(fp32_flat_groups[0]) + merged_single_partition_of_fp32_groups = [] + for i in range(num_param_groups): + merged_partitions = [sd[i] for sd in fp32_flat_groups] + full_single_fp32_vector = torch.cat(merged_partitions, 0) + merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) + avail_numel = sum([ + full_single_fp32_vector.numel() + for full_single_fp32_vector in merged_single_partition_of_fp32_groups + ]) + + if debug: + wanted_params = sum([len(shapes) for shapes in param_shapes]) + wanted_numel = sum( + [sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + state_dict = OrderedDict() + + # buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + total_numel = 0 + total_params = 0 + for shapes, full_single_fp32_vector in 
zip(param_shapes, merged_single_partition_of_fp32_groups): + offset = 0 + avail_numel = full_single_fp32_vector.numel() + for name, shape in shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + if debug: + print( + f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} " + ) + state_dict[name] = full_single_fp32_vector.narrow( + 0, + offset, + unpartitioned_numel).view(shape) + offset += unpartitioned_numel + + # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and + # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex + # paddings performed in the code it's almost impossible to predict the exact numbers w/o the + # live optimizer object, so we are checking that the numbers are within the right range + align_to = 2 * world_size + + def zero2_align(x): + return align_to * math.ceil(x / align_to) + + if debug: + print(f"original offset={offset}, avail_numel={avail_numel}") + + offset = zero2_align(offset) + avail_numel = zero2_align(avail_numel) + + if debug: + print(f"aligned offset={offset}, avail_numel={avail_numel}") + + # Sanity check + if offset != avail_numel: + raise ValueError( + f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print( + f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements" + ) + + return state_dict + + +def zero3_partitioned_param_info(unpartitioned_numel, world_size): + remainder = unpartitioned_numel % world_size + padding_numel = (world_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + return partitioned_numel, padding_numel + + +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers): + + # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each + # param, re-consolidating each param, while dealing with padding if any + + avail_numel = fp32_flat_groups[0].numel() * world_size + # merge list of dicts, preserving order + param_shapes = {k: v for d in param_shapes for k, v in d.items()} + + if debug: + for i in range(world_size): + print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") + + wanted_params = len(param_shapes) + wanted_numel = sum(shape.numel() for shape in param_shapes.values()) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + state_dict = OrderedDict() + + # buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + offset = 0 + total_numel = 0 + total_params = 0 + for name, shape in param_shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + # XXX: memory usage doubles here + state_dict[name] = torch.cat( + tuple(fp32_flat_groups[i].narrow(0, + offset, + partitioned_numel) + for i in range(world_size)), + 0).narrow(0, + 0, + 
unpartitioned_numel).view(shape) + offset += partitioned_numel + + offset *= world_size + + # Sanity check + if offset != avail_numel: + raise ValueError( + f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print( + f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements" + ) + + return state_dict + + +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with + ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example + via a model hub. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` + + Returns: + - pytorch ``state_dict`` + + Note: this approach may not work if your application doesn't have sufficient free CPU memory and + you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with + the checkpoint. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + # do the training and checkpoint saving + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu + model = model.cpu() # move to cpu + model.load_state_dict(state_dict) + # submit to model hub or save the model to share with others + + In this example the ``model`` will no longer be usable in the deepspeed context of the same + application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead. + + """ + if tag is None: + latest_path = os.path.join(checkpoint_dir, 'latest') + if os.path.isfile(latest_path): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + raise ValueError(f"Unable to find 'latest' file at {latest_path}") + + ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) + + if not os.path.isdir(ds_checkpoint_dir): + raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") + + return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir) + + +def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be + loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + """ + + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + print(f"Saving fp32 state dict to {output_file}") + torch.save(state_dict, output_file) + + +def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): + """ + 1. Put the provided model to cpu + 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` + 3. 
Load it into the provided model + + Args: + - ``model``: the model object to update + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + + Returns: + - ``model`: modified model + + Make sure you have plenty of CPU memory available before you call this function. If you don't + have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it + conveniently placed for you in the checkpoint folder. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint + model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir) + # submit to model hub or save the model to share with others + + Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context + of the same application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + """ + logger.info(f"Extracting fp32 weights") + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + + logger.info(f"Overwriting model with fp32 weights") + model = model.cpu() + model.load_state_dict(state_dict, strict=False) + + return model + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument( + "checkpoint_dir", + type=str, + help="path to the desired checkpoint folder, e.g., path/checkpoint-12") + parser.add_argument( + "output_file", + type=str, + help= + "path to the pytorch fp32 state_dict output file (e.g. 
path/checkpoint-12/pytorch_model.bin)" + ) + parser.add_argument("-d", "--debug", action='store_true', help="enable debug") + args = parser.parse_args() + + debug = args.debug + + convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file) diff --git a/checkpoint-6000/latest b/checkpoint-6000/latest new file mode 100644 index 0000000000000000000000000000000000000000..7b2c8602be034ae63f23b293f8d037fa7afa0c54 --- /dev/null +++ b/checkpoint-6000/latest @@ -0,0 +1 @@ +pytorch_model \ No newline at end of file diff --git a/checkpoint-6000/pytorch_model/mp_rank_00_model_states.pt b/checkpoint-6000/pytorch_model/mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..b6688fb20daadf49fd2ea7fef1218cbcd3b5bb3e --- /dev/null +++ b/checkpoint-6000/pytorch_model/mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:28ad544805f1f73b3f389695c34eacefcdb9582dcd77442eb4ab6e9bb30149d4 +size 1658603 diff --git a/checkpoint-6000/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt b/checkpoint-6000/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..77edc8bc763dedafb05e8023e555fb001a0e22d1 --- /dev/null +++ b/checkpoint-6000/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b638ae5acfc32c60b0d988b618e395e89883e8c6c46c8de1fe3558a2a50f52ed +size 9586591 diff --git a/checkpoint-6000/random_states_0.pkl b/checkpoint-6000/random_states_0.pkl new file mode 100644 index 0000000000000000000000000000000000000000..6490f492f2a101137343d6d73c67608880af7bc5 --- /dev/null +++ b/checkpoint-6000/random_states_0.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a343d7e430a50eaadbf75be6ecb6e0ed0a5e64fe847843d7db979a371ccc3c87 +size 14631 diff --git a/checkpoint-6000/scheduler.bin b/checkpoint-6000/scheduler.bin new file mode 100644 index 0000000000000000000000000000000000000000..d989a50e5615e76d62468a987ad631cd06e03a71 --- /dev/null +++ b/checkpoint-6000/scheduler.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9587decdba47406f332d08ca89ddf17ff43c7a90bbb247590ca8604591f66ef8 +size 559 diff --git a/checkpoint-6000/zero_to_fp32.py b/checkpoint-6000/zero_to_fp32.py new file mode 100644 index 0000000000000000000000000000000000000000..e5249853c89179a3d5ad212ead8c464d6c979df5 --- /dev/null +++ b/checkpoint-6000/zero_to_fp32.py @@ -0,0 +1,482 @@ +#!/usr/bin/env python + +# This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets +# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in +# the future. Once extracted, the weights don't require DeepSpeed and can be used in any +# application. +# +# example: python zero_to_fp32.py . pytorch_model.bin + +import argparse +import torch +import glob +import math +import os +import re +from collections import OrderedDict + +# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with +# DeepSpeed data structures it has to be available in the current python environment. 
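+# The 'latest' file in this folder contains the tag 'pytorch_model', so the
+# ZeRO shards resolve to checkpoint-6000/pytorch_model/. A sketch of passing
+# that tag explicitly instead of relying on the 'latest' lookup (illustrative
+# paths, using the convert_zero_checkpoint_to_fp32_state_dict defined below):
+#
+#   from zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
+#   convert_zero_checkpoint_to_fp32_state_dict(
+#       "checkpoint-6000", "checkpoint-6000/pytorch_model.bin", tag="pytorch_model")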
+from deepspeed.utils import logger +from deepspeed.checkpoint.constants import (DS_VERSION, + OPTIMIZER_STATE_DICT, + SINGLE_PARTITION_OF_FP32_GROUPS, + FP32_FLAT_GROUPS, + ZERO_STAGE, + PARTITION_COUNT, + PARAM_SHAPES, + BUFFER_NAMES) + +debug = 0 + +# load to cpu +device = torch.device('cpu') + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + ''' + return [atoi(c) for c in re.split(r'(\d+)', text)] + + +def get_model_state_file(checkpoint_dir, zero_stage): + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + + # there should be only one file + if zero_stage == 2: + file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") + elif zero_stage == 3: + file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") + + if not os.path.exists(file): + raise FileNotFoundError(f"can't find model states file at '{file}'") + + return file + + +def get_optim_files(checkpoint_dir): + # XXX: need to test that this simple glob rule works for multi-node setup too + optim_files = sorted(glob.glob(os.path.join(checkpoint_dir, + "*_optim_states.pt")), + key=natural_keys) + + if len(optim_files) == 0: + raise FileNotFoundError( + f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'") + + return optim_files + + +def parse_model_state(file): + state_dict = torch.load(file, map_location=device) + + if BUFFER_NAMES not in state_dict: + raise ValueError(f"{file} is not a model state checkpoint") + buffer_names = state_dict[BUFFER_NAMES] + if debug: + print("Found buffers:", buffer_names) + + # recover just the buffers while restoring them to fp32 if they were saved in fp16 + buffers = { + k: v.float() + for k, + v in state_dict["module"].items() if k in buffer_names + } + param_shapes = state_dict[PARAM_SHAPES] + + ds_version = state_dict.get(DS_VERSION, None) + + return buffers, param_shapes, ds_version + + +def parse_optim_states(files, ds_checkpoint_dir): + + total_files = len(files) + state_dicts = [] + for f in files: + state_dicts.append(torch.load(f, map_location=device)) + + if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: + raise ValueError(f"{files[0]} is not a zero checkpoint") + zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] + world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] + + # For ZeRO-2 each param group can have different partition_count as data parallelism for expert + # parameters can be different from data parallelism for non-expert parameters. So we can just + # use the max of the partition_count to get the dp world_size. + + if type(world_size) is list: + world_size = max(world_size) + + if world_size != total_files: + raise ValueError( + f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. " + "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." 
+ ) + + # the groups are named differently in each stage + if zero_stage == 2: + fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS + elif zero_stage == 3: + fp32_groups_key = FP32_FLAT_GROUPS + else: + raise ValueError(f"unknown zero stage {zero_stage}") + + if zero_stage == 2: + fp32_flat_groups = [ + state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] + for i in range(len(state_dicts)) + ] + elif zero_stage == 3: + # if there is more than one param group, there will be multiple flattened tensors - one + # flattened tensor per group - for simplicity merge them into a single tensor + # + # XXX: could make the script more memory efficient for when there are multiple groups - it + # will require matching the sub-lists of param_shapes for each param group flattened tensor + + fp32_flat_groups = [ + torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], + 0) for i in range(len(state_dicts)) + ] + + return zero_stage, world_size, fp32_flat_groups + + +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): + """ + Returns fp32 state_dict reconstructed from ds checkpoint + + Args: + - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) + + """ + print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") + + optim_files = get_optim_files(ds_checkpoint_dir) + zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) + print( + f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") + + model_file = get_model_state_file(ds_checkpoint_dir, zero_stage) + buffers, param_shapes, ds_version = parse_model_state(model_file) + print(f'Parsing checkpoint created by deepspeed=={ds_version}') + + if zero_stage == 2: + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers) + elif zero_stage == 3: + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers) + + +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers): + + # Reconstruction protocol: + # + # XXX: document this + + if debug: + for i in range(world_size): + for j in range(len(fp32_flat_groups[0])): + print( + f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") + + # XXX: memory usage doubles here (zero2) + num_param_groups = len(fp32_flat_groups[0]) + merged_single_partition_of_fp32_groups = [] + for i in range(num_param_groups): + merged_partitions = [sd[i] for sd in fp32_flat_groups] + full_single_fp32_vector = torch.cat(merged_partitions, 0) + merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) + avail_numel = sum([ + full_single_fp32_vector.numel() + for full_single_fp32_vector in merged_single_partition_of_fp32_groups + ]) + + if debug: + wanted_params = sum([len(shapes) for shapes in param_shapes]) + wanted_numel = sum( + [sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + state_dict = OrderedDict() + + # buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + total_numel = 0 + total_params = 0 + for shapes, full_single_fp32_vector in 
zip(param_shapes, merged_single_partition_of_fp32_groups): + offset = 0 + avail_numel = full_single_fp32_vector.numel() + for name, shape in shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + if debug: + print( + f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} " + ) + state_dict[name] = full_single_fp32_vector.narrow( + 0, + offset, + unpartitioned_numel).view(shape) + offset += unpartitioned_numel + + # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and + # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex + # paddings performed in the code it's almost impossible to predict the exact numbers w/o the + # live optimizer object, so we are checking that the numbers are within the right range + align_to = 2 * world_size + + def zero2_align(x): + return align_to * math.ceil(x / align_to) + + if debug: + print(f"original offset={offset}, avail_numel={avail_numel}") + + offset = zero2_align(offset) + avail_numel = zero2_align(avail_numel) + + if debug: + print(f"aligned offset={offset}, avail_numel={avail_numel}") + + # Sanity check + if offset != avail_numel: + raise ValueError( + f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print( + f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements" + ) + + return state_dict + + +def zero3_partitioned_param_info(unpartitioned_numel, world_size): + remainder = unpartitioned_numel % world_size + padding_numel = (world_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + return partitioned_numel, padding_numel + + +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers): + + # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each + # param, re-consolidating each param, while dealing with padding if any + + avail_numel = fp32_flat_groups[0].numel() * world_size + # merge list of dicts, preserving order + param_shapes = {k: v for d in param_shapes for k, v in d.items()} + + if debug: + for i in range(world_size): + print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") + + wanted_params = len(param_shapes) + wanted_numel = sum(shape.numel() for shape in param_shapes.values()) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + state_dict = OrderedDict() + + # buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + offset = 0 + total_numel = 0 + total_params = 0 + for name, shape in param_shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + # XXX: memory usage doubles here + state_dict[name] = torch.cat( + tuple(fp32_flat_groups[i].narrow(0, + offset, + partitioned_numel) + for i in range(world_size)), + 0).narrow(0, + 0, + 
unpartitioned_numel).view(shape) + offset += partitioned_numel + + offset *= world_size + + # Sanity check + if offset != avail_numel: + raise ValueError( + f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print( + f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements" + ) + + return state_dict + + +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with + ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example + via a model hub. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` + + Returns: + - pytorch ``state_dict`` + + Note: this approach may not work if your application doesn't have sufficient free CPU memory and + you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with + the checkpoint. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + # do the training and checkpoint saving + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu + model = model.cpu() # move to cpu + model.load_state_dict(state_dict) + # submit to model hub or save the model to share with others + + In this example the ``model`` will no longer be usable in the deepspeed context of the same + application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead. + + """ + if tag is None: + latest_path = os.path.join(checkpoint_dir, 'latest') + if os.path.isfile(latest_path): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + raise ValueError(f"Unable to find 'latest' file at {latest_path}") + + ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) + + if not os.path.isdir(ds_checkpoint_dir): + raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") + + return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir) + + +def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be + loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + """ + + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + print(f"Saving fp32 state dict to {output_file}") + torch.save(state_dict, output_file) + + +def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): + """ + 1. Put the provided model to cpu + 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` + 3. 
Load it into the provided model + + Args: + - ``model``: the model object to update + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + + Returns: + - ``model`: modified model + + Make sure you have plenty of CPU memory available before you call this function. If you don't + have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it + conveniently placed for you in the checkpoint folder. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint + model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir) + # submit to model hub or save the model to share with others + + Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context + of the same application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + """ + logger.info(f"Extracting fp32 weights") + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + + logger.info(f"Overwriting model with fp32 weights") + model = model.cpu() + model.load_state_dict(state_dict, strict=False) + + return model + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument( + "checkpoint_dir", + type=str, + help="path to the desired checkpoint folder, e.g., path/checkpoint-12") + parser.add_argument( + "output_file", + type=str, + help= + "path to the pytorch fp32 state_dict output file (e.g. 
path/checkpoint-12/pytorch_model.bin)" + ) + parser.add_argument("-d", "--debug", action='store_true', help="enable debug") + args = parser.parse_args() + + debug = args.debug + + convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file) diff --git a/checkpoint-6500/latest b/checkpoint-6500/latest new file mode 100644 index 0000000000000000000000000000000000000000..7b2c8602be034ae63f23b293f8d037fa7afa0c54 --- /dev/null +++ b/checkpoint-6500/latest @@ -0,0 +1 @@ +pytorch_model \ No newline at end of file diff --git a/checkpoint-6500/pytorch_model/mp_rank_00_model_states.pt b/checkpoint-6500/pytorch_model/mp_rank_00_model_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..94f5ad4f8d959f294feba463145e465114913c51 --- /dev/null +++ b/checkpoint-6500/pytorch_model/mp_rank_00_model_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9f2d255cc2b8b17f3207cd8a5f6c8b076289d2888c9baf2e86b6a1b596c72ede +size 1658603 diff --git a/checkpoint-6500/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt b/checkpoint-6500/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt new file mode 100644 index 0000000000000000000000000000000000000000..2de3afd55372b402e35a3701ca2857deab24a4b7 --- /dev/null +++ b/checkpoint-6500/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:db249e8b96291303960a61944b03c5846deb795a2862ff5d8746a1ef09e47def +size 9586591 diff --git a/checkpoint-6500/random_states_0.pkl b/checkpoint-6500/random_states_0.pkl new file mode 100644 index 0000000000000000000000000000000000000000..4b31f8c9f76b3abcc9f18724022b274f2378700d --- /dev/null +++ b/checkpoint-6500/random_states_0.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3c62939b33d5f054e31bfac79fc0bc75445c2e2c2b8abb43889781c521fee522 +size 14631 diff --git a/checkpoint-6500/scheduler.bin b/checkpoint-6500/scheduler.bin new file mode 100644 index 0000000000000000000000000000000000000000..dcd710cc8816712fc0dd3ec20ddfbb40e27f1690 --- /dev/null +++ b/checkpoint-6500/scheduler.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:edc08654b9d8e85fd12c1115c40af956a9f51f4a991229ad01cdb85855ad37e3 +size 559 diff --git a/checkpoint-6500/zero_to_fp32.py b/checkpoint-6500/zero_to_fp32.py new file mode 100644 index 0000000000000000000000000000000000000000..e5249853c89179a3d5ad212ead8c464d6c979df5 --- /dev/null +++ b/checkpoint-6500/zero_to_fp32.py @@ -0,0 +1,482 @@ +#!/usr/bin/env python + +# This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets +# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in +# the future. Once extracted, the weights don't require DeepSpeed and can be used in any +# application. +# +# example: python zero_to_fp32.py . pytorch_model.bin + +import argparse +import torch +import glob +import math +import os +import re +from collections import OrderedDict + +# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with +# DeepSpeed data structures it has to be available in the current python environment. 
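+# A sketch of the all-in-one helper defined below (illustrative only;
+# 'trainer.model' stands in for a hypothetical live model object). Per the
+# docstring's caveat, the returned model is no longer usable inside the same
+# DeepSpeed engine without re-initializing it:
+#
+#   from zero_to_fp32 import load_state_dict_from_zero_checkpoint
+#   model = load_state_dict_from_zero_checkpoint(trainer.model, "checkpoint-6500")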
+from deepspeed.utils import logger +from deepspeed.checkpoint.constants import (DS_VERSION, + OPTIMIZER_STATE_DICT, + SINGLE_PARTITION_OF_FP32_GROUPS, + FP32_FLAT_GROUPS, + ZERO_STAGE, + PARTITION_COUNT, + PARAM_SHAPES, + BUFFER_NAMES) + +debug = 0 + +# load to cpu +device = torch.device('cpu') + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + ''' + return [atoi(c) for c in re.split(r'(\d+)', text)] + + +def get_model_state_file(checkpoint_dir, zero_stage): + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + + # there should be only one file + if zero_stage == 2: + file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") + elif zero_stage == 3: + file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") + + if not os.path.exists(file): + raise FileNotFoundError(f"can't find model states file at '{file}'") + + return file + + +def get_optim_files(checkpoint_dir): + # XXX: need to test that this simple glob rule works for multi-node setup too + optim_files = sorted(glob.glob(os.path.join(checkpoint_dir, + "*_optim_states.pt")), + key=natural_keys) + + if len(optim_files) == 0: + raise FileNotFoundError( + f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'") + + return optim_files + + +def parse_model_state(file): + state_dict = torch.load(file, map_location=device) + + if BUFFER_NAMES not in state_dict: + raise ValueError(f"{file} is not a model state checkpoint") + buffer_names = state_dict[BUFFER_NAMES] + if debug: + print("Found buffers:", buffer_names) + + # recover just the buffers while restoring them to fp32 if they were saved in fp16 + buffers = { + k: v.float() + for k, + v in state_dict["module"].items() if k in buffer_names + } + param_shapes = state_dict[PARAM_SHAPES] + + ds_version = state_dict.get(DS_VERSION, None) + + return buffers, param_shapes, ds_version + + +def parse_optim_states(files, ds_checkpoint_dir): + + total_files = len(files) + state_dicts = [] + for f in files: + state_dicts.append(torch.load(f, map_location=device)) + + if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: + raise ValueError(f"{files[0]} is not a zero checkpoint") + zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] + world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] + + # For ZeRO-2 each param group can have different partition_count as data parallelism for expert + # parameters can be different from data parallelism for non-expert parameters. So we can just + # use the max of the partition_count to get the dp world_size. + + if type(world_size) is list: + world_size = max(world_size) + + if world_size != total_files: + raise ValueError( + f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. " + "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." 
+ ) + + # the groups are named differently in each stage + if zero_stage == 2: + fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS + elif zero_stage == 3: + fp32_groups_key = FP32_FLAT_GROUPS + else: + raise ValueError(f"unknown zero stage {zero_stage}") + + if zero_stage == 2: + fp32_flat_groups = [ + state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] + for i in range(len(state_dicts)) + ] + elif zero_stage == 3: + # if there is more than one param group, there will be multiple flattened tensors - one + # flattened tensor per group - for simplicity merge them into a single tensor + # + # XXX: could make the script more memory efficient for when there are multiple groups - it + # will require matching the sub-lists of param_shapes for each param group flattened tensor + + fp32_flat_groups = [ + torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], + 0) for i in range(len(state_dicts)) + ] + + return zero_stage, world_size, fp32_flat_groups + + +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): + """ + Returns fp32 state_dict reconstructed from ds checkpoint + + Args: + - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) + + """ + print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") + + optim_files = get_optim_files(ds_checkpoint_dir) + zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) + print( + f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") + + model_file = get_model_state_file(ds_checkpoint_dir, zero_stage) + buffers, param_shapes, ds_version = parse_model_state(model_file) + print(f'Parsing checkpoint created by deepspeed=={ds_version}') + + if zero_stage == 2: + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers) + elif zero_stage == 3: + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers) + + +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, + param_shapes, + fp32_flat_groups, + buffers): + + # Reconstruction protocol: + # + # XXX: document this + + if debug: + for i in range(world_size): + for j in range(len(fp32_flat_groups[0])): + print( + f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") + + # XXX: memory usage doubles here (zero2) + num_param_groups = len(fp32_flat_groups[0]) + merged_single_partition_of_fp32_groups = [] + for i in range(num_param_groups): + merged_partitions = [sd[i] for sd in fp32_flat_groups] + full_single_fp32_vector = torch.cat(merged_partitions, 0) + merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) + avail_numel = sum([ + full_single_fp32_vector.numel() + for full_single_fp32_vector in merged_single_partition_of_fp32_groups + ]) + + if debug: + wanted_params = sum([len(shapes) for shapes in param_shapes]) + wanted_numel = sum( + [sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + state_dict = OrderedDict() + + # buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + total_numel = 0 + total_params = 0 + for shapes, full_single_fp32_vector in 
+
+
+def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
+    """
+    Returns fp32 state_dict reconstructed from ds checkpoint
+
+    Args:
+        - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
+
+    """
+    print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
+
+    optim_files = get_optim_files(ds_checkpoint_dir)
+    zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
+    print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
+
+    model_file = get_model_state_file(ds_checkpoint_dir, zero_stage)
+    buffers, param_shapes, ds_version = parse_model_state(model_file)
+    print(f'Parsing checkpoint created by deepspeed=={ds_version}')
+
+    if zero_stage == 2:
+        return _get_fp32_state_dict_from_zero2_checkpoint(world_size,
+                                                          param_shapes,
+                                                          fp32_flat_groups,
+                                                          buffers)
+    elif zero_stage == 3:
+        return _get_fp32_state_dict_from_zero3_checkpoint(world_size,
+                                                          param_shapes,
+                                                          fp32_flat_groups,
+                                                          buffers)
+
+
+def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
+                                               param_shapes,
+                                               fp32_flat_groups,
+                                               buffers):
+
+    # Reconstruction protocol:
+    #
+    # XXX: document this
+
+    if debug:
+        for i in range(world_size):
+            for j in range(len(fp32_flat_groups[0])):
+                print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
+
+    # XXX: memory usage doubles here (zero2)
+    num_param_groups = len(fp32_flat_groups[0])
+    merged_single_partition_of_fp32_groups = []
+    for i in range(num_param_groups):
+        merged_partitions = [sd[i] for sd in fp32_flat_groups]
+        full_single_fp32_vector = torch.cat(merged_partitions, 0)
+        merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
+    avail_numel = sum(
+        full_single_fp32_vector.numel()
+        for full_single_fp32_vector in merged_single_partition_of_fp32_groups)
+
+    if debug:
+        wanted_params = sum(len(shapes) for shapes in param_shapes)
+        wanted_numel = sum(
+            sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes)
+        # not asserting if there is a mismatch due to possible padding
+        print(f"Have {avail_numel} numels to process.")
+        print(f"Need {wanted_numel} numels in {wanted_params} params.")
+
+    state_dict = OrderedDict()
+
+    # buffers
+    state_dict.update(buffers)
+    if debug:
+        print(f"added {len(buffers)} buffers")
+
+    # params
+    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to
+    # support an out-of-core computing solution
+    total_numel = 0
+    total_params = 0
+    for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
+        offset = 0
+        avail_numel = full_single_fp32_vector.numel()
+        for name, shape in shapes.items():
+
+            unpartitioned_numel = shape.numel()
+            total_numel += unpartitioned_numel
+            total_params += 1
+
+            if debug:
+                print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel}")
+            state_dict[name] = full_single_fp32_vector.narrow(
+                0, offset, unpartitioned_numel).view(shape)
+            offset += unpartitioned_numel
+
+        # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
+        # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
+        # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
+        # live optimizer object, so we are checking that the numbers are within the right range
+        align_to = 2 * world_size
+
+        def zero2_align(x):
+            return align_to * math.ceil(x / align_to)
+
+        if debug:
+            print(f"original offset={offset}, avail_numel={avail_numel}")
+
+        offset = zero2_align(offset)
+        avail_numel = zero2_align(avail_numel)
+
+        if debug:
+            print(f"aligned offset={offset}, avail_numel={avail_numel}")
+
+        # Sanity check
+        if offset != avail_numel:
+            raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
+
+    print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
+
+    return state_dict
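+
+# Rough picture of the zero2 merge above (illustrative): with world_size=2 and one param
+# group, rank 0's partition holds [p0, p1, first half of p2] and rank 1's holds
+# [rest of p2, p3, padding]; torch.cat over the rank partitions restores one flat fp32
+# vector, and narrow()/view() then slices each param back out at its recorded shape.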
+
+
+def zero3_partitioned_param_info(unpartitioned_numel, world_size):
+    remainder = unpartitioned_numel % world_size
+    padding_numel = (world_size - remainder) if remainder else 0
+    partitioned_numel = math.ceil(unpartitioned_numel / world_size)
+    return partitioned_numel, padding_numel
+
+
+def _get_fp32_state_dict_from_zero3_checkpoint(world_size,
+                                               param_shapes,
+                                               fp32_flat_groups,
+                                               buffers):
+
+    # Reconstruction protocol: For zero3 we need to zip the partitions together at the boundary
+    # of each param, re-consolidating each param, while dealing with padding if any
+
+    avail_numel = fp32_flat_groups[0].numel() * world_size
+    # merge list of dicts, preserving order
+    param_shapes = {k: v for d in param_shapes for k, v in d.items()}
+
+    if debug:
+        for i in range(world_size):
+            print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
+
+    wanted_params = len(param_shapes)
+    wanted_numel = sum(shape.numel() for shape in param_shapes.values())
+    # not asserting if there is a mismatch due to possible padding
+    print(f"Have {avail_numel} numels to process.")
+    print(f"Need {wanted_numel} numels in {wanted_params} params.")
+
+    state_dict = OrderedDict()
+
+    # buffers
+    state_dict.update(buffers)
+    if debug:
+        print(f"added {len(buffers)} buffers")
+
+    # params
+    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to
+    # support an out-of-core computing solution
+    offset = 0
+    total_numel = 0
+    total_params = 0
+    for name, shape in param_shapes.items():
+
+        unpartitioned_numel = shape.numel()
+        total_numel += unpartitioned_numel
+        total_params += 1
+
+        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
+
+        if debug:
+            print(
+                f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
+            )
+
+        # XXX: memory usage doubles here
+        state_dict[name] = torch.cat(
+            tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel)
+                  for i in range(world_size)),
+            0).narrow(0, 0, unpartitioned_numel).view(shape)
+        offset += partitioned_numel
+
+    offset *= world_size
+
+    # Sanity check
+    if offset != avail_numel:
+        raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
+
+    print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
+
+    return state_dict
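+
+# Worked example (illustrative): for a param with unpartitioned_numel=10 and world_size=4,
+# zero3_partitioned_param_info() returns partitioned_numel=3 and padding_numel=2 - each of
+# the 4 ranks stores 3 elements (12 total, 2 of them padding), and the
+# torch.cat(...).narrow(0, 0, unpartitioned_numel) above drops that padding on reassembly.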
+
+
+def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
+    """
+    Convert a ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be
+    loaded with ``load_state_dict()`` and used for training without DeepSpeed, or shared with
+    others, for example via a model hub.
+
+    Args:
+        - ``checkpoint_dir``: path to the desired checkpoint folder
+        - ``tag``: checkpoint tag used as a unique identifier for the checkpoint. If not provided, will attempt to load the tag from the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+
+    Returns:
+        - pytorch ``state_dict``
+
+    Note: this approach may not work if your application doesn't have sufficient free CPU
+    memory, in which case you may need to use the offline approach via the ``zero_to_fp32.py``
+    script that is saved with the checkpoint.
+
+    A typical usage might be ::
+
+        from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
+        # do the training and checkpoint saving
+        state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
+        model = model.cpu() # move to cpu
+        model.load_state_dict(state_dict)
+        # submit to model hub or save the model to share with others
+
+    In this example the ``model`` will no longer be usable in the deepspeed context of the same
+    application, i.e. you will need to re-initialize the deepspeed engine, since
+    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
+
+    If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
+
+    """
+    if tag is None:
+        latest_path = os.path.join(checkpoint_dir, 'latest')
+        if os.path.isfile(latest_path):
+            with open(latest_path, 'r') as fd:
+                tag = fd.read().strip()
+        else:
+            raise ValueError(f"Unable to find 'latest' file at {latest_path}")
+
+    ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
+
+    if not os.path.isdir(ds_checkpoint_dir):
+        raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
+
+    return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)
+
+
+def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
+    """
+    Convert a ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that
+    can be loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training
+    without DeepSpeed.
+
+    Args:
+        - ``checkpoint_dir``: path to the desired checkpoint folder (one that contains the tag-folder, like ``global_step14``)
+        - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
+        - ``tag``: checkpoint tag used as a unique identifier for the checkpoint. If not provided, will attempt to load the tag from the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+    """
+
+    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
+    print(f"Saving fp32 state dict to {output_file}")
+    torch.save(state_dict, output_file)
+
+
+def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
+    """
+    1. Put the provided model on the cpu
+    2. Convert a ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
+    3. Load it into the provided model
+
+    Args:
+        - ``model``: the model object to update
+        - ``checkpoint_dir``: path to the desired checkpoint folder (one that contains the tag-folder, like ``global_step14``)
+        - ``tag``: checkpoint tag used as a unique identifier for the checkpoint. If not provided, will attempt to load the tag from the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+
+    Returns:
+        - ``model``: the modified model
+
+    Make sure you have plenty of CPU memory available before you call this function. If you
+    don't have enough, use the ``zero_to_fp32.py`` utility to do the conversion. You will find
+    it conveniently placed for you in the checkpoint folder.
+
+    A typical usage might be ::
+
+        from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
+        model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
+        # submit to model hub or save the model to share with others
+
+    Note that once this has run, the ``model`` will no longer be usable in the deepspeed
+    context of the same application, i.e. you will need to re-initialize the deepspeed engine,
+    since ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
+
+    """
+    logger.info("Extracting fp32 weights")
+    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
+
+    logger.info("Overwriting model with fp32 weights")
+    model = model.cpu()
+    model.load_state_dict(state_dict, strict=False)
+
+    return model
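+
+# A typical offline invocation against this repo's checkpoints might look like the
+# following (the output filename is hypothetical; any writable path works):
+#
+#   python checkpoint-1000/zero_to_fp32.py checkpoint-1000 pytorch_model_fp32.bin
+#
+# The first argument is the folder containing the `latest` tag file; the script reads the
+# tag (`pytorch_model` in these checkpoints) and consolidates the shards found under it.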
+
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "checkpoint_dir",
+        type=str,
+        help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
+    parser.add_argument(
+        "output_file",
+        type=str,
+        help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
+    parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
+    args = parser.parse_args()
+
+    debug = args.debug
+
+    convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file)
diff --git a/checkpoint-7000/latest b/checkpoint-7000/latest
new file mode 100644
index 0000000000000000000000000000000000000000..7b2c8602be034ae63f23b293f8d037fa7afa0c54
--- /dev/null
+++ b/checkpoint-7000/latest
@@ -0,0 +1 @@
+pytorch_model
\ No newline at end of file
diff --git a/checkpoint-7000/pytorch_model/mp_rank_00_model_states.pt b/checkpoint-7000/pytorch_model/mp_rank_00_model_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..0cba80ab05f65a65fef35922a030bfb9a96d9dce
--- /dev/null
+++ b/checkpoint-7000/pytorch_model/mp_rank_00_model_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:17205f192fa2909d961c5e869bc46d91e0d713bdfbab174d9085305f1d16ab9d
+size 1658603
diff --git a/checkpoint-7000/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt b/checkpoint-7000/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..3af3e959f107ffe12649d8530ef134020729137a
--- /dev/null
+++ b/checkpoint-7000/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9f39f2ddc5706dcdcff9552160aef1fded7a3ce35b01e5b9e338505a4e8d781d
+size 9586591
diff --git a/checkpoint-7000/random_states_0.pkl b/checkpoint-7000/random_states_0.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..3d6f5f594fcb06e3e69f4c57550f99a7fa32e250
--- /dev/null
+++ b/checkpoint-7000/random_states_0.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5061c9d72a54ec3317a3e1e6c062af1a587d586eeb5300d6932c959f04c5dde5
+size 14631
diff --git a/checkpoint-7000/scheduler.bin b/checkpoint-7000/scheduler.bin
new file mode 100644
index 0000000000000000000000000000000000000000..91a56e4265ac883f2d39d00e4bf386a94a1dc16d
--- /dev/null
+++ b/checkpoint-7000/scheduler.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9ee84ae10b603265dc3c0feea5f489a616ba55b5700671ea09d81a2161d59bef
+size 559
diff --git a/checkpoint-7000/zero_to_fp32.py b/checkpoint-7000/zero_to_fp32.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5249853c89179a3d5ad212ead8c464d6c979df5
diff --git a/checkpoint-7500/latest b/checkpoint-7500/latest
new file mode 100644
index 0000000000000000000000000000000000000000..7b2c8602be034ae63f23b293f8d037fa7afa0c54
--- /dev/null
+++ b/checkpoint-7500/latest
@@ -0,0 +1 @@
+pytorch_model
\ No newline at end of file
diff --git a/checkpoint-7500/pytorch_model/mp_rank_00_model_states.pt b/checkpoint-7500/pytorch_model/mp_rank_00_model_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..cf91f942d68b80943cc890f212f3c45bc0eb5aa7
--- /dev/null
+++ b/checkpoint-7500/pytorch_model/mp_rank_00_model_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a216b279f95f4448f8ba027693f2a2b05f237aa982bdd604dbfc4b5aa88c9d53
+size 1658603
diff --git a/checkpoint-7500/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt b/checkpoint-7500/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt
new file mode 100644
index 0000000000000000000000000000000000000000..637cef5c2e5cc669d6e4fad8e26f5d77324ba88f
--- /dev/null
+++ b/checkpoint-7500/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4a2755e764a831b8354efd0a6aec9bbbb5b88e719b3818fc3c64dea5ea4f8a9c
+size 9586591
diff --git a/checkpoint-7500/random_states_0.pkl b/checkpoint-7500/random_states_0.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..416dcd1e2f40f6b5bbad474bf483d0c586081ce7
--- /dev/null
+++ b/checkpoint-7500/random_states_0.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a22fcb6455ef3518eeb11a4aa9481339f03f0691106b9a33aad3f467c40f7863
+size 14631
diff --git a/checkpoint-7500/scheduler.bin b/checkpoint-7500/scheduler.bin
new file mode 100644
index 0000000000000000000000000000000000000000..1aee67755a70266275aa4bb93be8ed1ebf2dacd3
--- /dev/null
+++ b/checkpoint-7500/scheduler.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:248165686a5528c897c73f3bfc7fd43d5c47a969f6c78429e39dd7241a8db1a8
+size 559
diff --git a/checkpoint-7500/zero_to_fp32.py b/checkpoint-7500/zero_to_fp32.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5249853c89179a3d5ad212ead8c464d6c979df5
diff --git a/image_0.png b/image_0.png
new file mode 100644
index 0000000000000000000000000000000000000000..6bedeb7111dec9e68cf439c37a2159067200fc8d
--- /dev/null
+++ b/image_0.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:04c814eed56c8473ee52cdf4e2bc12aca27eb179a496c7c899f28c0f4b9d7be6
+size 395288
diff --git a/image_1.png b/image_1.png
new file mode 100644
index 0000000000000000000000000000000000000000..11d4d06171b340f01c766a0d901d5c7bebec7489
--- /dev/null
+++ b/image_1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e9c4a470ad900801f7de4f9402eb27af8a1cc00eac80d618ef16bac39fb27d33
+size 842
diff --git a/image_2.png b/image_2.png
new file mode 100644
index 0000000000000000000000000000000000000000..032a550692ddb3baf12a38e2fb39c472c8e1cfc6
--- /dev/null
+++ b/image_2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fc33e6a67b54549cbbcb98ce5702a09120a5e4a31b772dff377861c787906143
+size 392031
diff --git a/image_3.png b/image_3.png
new file mode 100644
index 0000000000000000000000000000000000000000..04bed8f639a98c90b58651568e12c15e280c17bd
--- /dev/null
+++ b/image_3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fd754da6c17fc04b626830e4cb34b3b5dfabf8a07d332b6d0c400c41760572d4
+size 396875
diff --git a/pytorch_lora_weights.bin b/pytorch_lora_weights.bin
new file mode 100644
index 0000000000000000000000000000000000000000..5e9877c105293f8ff9ebe53d0ec99d60e51573cc
--- /dev/null
+++ b/pytorch_lora_weights.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b2456fcdb3f18466d1983e6b695911ab360ae33b8614a51058273edc43a3257e
+size 3283841
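
# Beyond the DeepSpeed conversion script above, a minimal sketch of consuming the LoRA
# weights this commit adds (untested; the base-model id and prompt are assumptions - use
# whichever Stable Diffusion checkpoint this adapter was actually trained against):

import torch
from diffusers import StableDiffusionPipeline

base_model = "runwayml/stable-diffusion-v1-5"  # assumed base checkpoint
pipe = StableDiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.float16)
# load_attn_procs() picks up pytorch_lora_weights.bin from the given folder or hub repo
pipe.unet.load_attn_procs(".")
pipe = pipe.to("cuda")

# generate with the LoRA-adapted UNet
image = pipe("a green pokemon with a leaf tail", num_inference_steps=30).images[0]
image.save("sample.png")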