diff --git a/deepspeed/checkpoint/ds_to_universal.py b/deepspeed/checkpoint/ds_to_universal.py
index e5974a30df22..c5f84753022b 100755
--- a/deepspeed/checkpoint/ds_to_universal.py
+++ b/deepspeed/checkpoint/ds_to_universal.py
@@ -5,16 +5,18 @@
 # DeepSpeed Team
 
-from functools import partial
-from itertools import chain
 import argparse
 import glob
 import itertools
 import math
-from concurrent.futures import ProcessPoolExecutor
 import os
 import re
 import shutil
+from collections import OrderedDict
+from concurrent.futures import ProcessPoolExecutor
+from functools import partial
+from itertools import chain
+
 import torch
 import tqdm
 #from pprint import pprint
 
@@ -109,7 +111,7 @@ def _save_checkpoint(file_path, chkpt_sd):
     torch.save(chkpt_sd, file_path)
 
 
-def extract_zero_shards(dir, ds_checkpoint, indices_3D):
+def extract_zero_shards(dir, ds_checkpoint, weight_only, data_type, indices_3D):
     pp_index, tp_index, dp_index = indices_3D
     sd = ds_checkpoint.get_zero_checkpoint_state(pp_index=pp_index, tp_index=tp_index, dp_index=dp_index)
 
@@ -121,19 +123,20 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D):
     pipeline_replicated_params = universal_checkpoint_info.get(PIPELINE_REPLICATED_PARAMETER_PATTERNS, [])
     # print(f'{pipeline_replicated_params=}')
 
-    # dict
     state_groups = optim_sd[BASE_OPTIMIZER_STATE]["state"]
-    # list
     fp32_groups = optim_sd[SINGLE_PARTITION_OF_FP32_GROUPS]
-    param_groups_cnt = len(state_groups)
 
-    for param_group_id in range(param_groups_cnt):
+    param_state = OrderedDict()
 
-        flat_state = dict(
-            exp_avg=state_groups[param_group_id]["exp_avg"],
-            exp_avg_sq=state_groups[param_group_id]["exp_avg_sq"],
-            fp32=fp32_groups[param_group_id],
-        )
+    for param_group_id in range(len(state_groups)):
+        if weight_only:
+            flat_state = dict(fp32=fp32_groups[param_group_id].detach(), )
+        else:
+            flat_state = dict(
+                exp_avg=state_groups[param_group_id]["exp_avg"],
+                exp_avg_sq=state_groups[param_group_id]["exp_avg_sq"],
+                fp32=fp32_groups[param_group_id],
+            )
 
         if "step" in state_groups[param_group_id]:
             flat_state["step"] = state_groups[param_group_id]["step"]
@@ -145,18 +148,25 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D):
             # pprint(f"dpt{dp_index}{pp_index}{tp_index} {param_group_id} {name} => {fragment_mapping.start}:{fragment_mapping.numel}")
 
             for state_key in flat_state.keys():
-                dump_param_fragment(dir, tp_index, dp_index, state_key, flat_state[state_key], name,
-                                    fragment_mapping.start, fragment_mapping.numel)
+                dump_param_fragment(param_state, dir, tp_index, dp_index, state_key, flat_state[state_key], name,
+                                    fragment_mapping.start, fragment_mapping.numel, data_type, weight_only)
+
+    return dp_index, param_state
 
 
-def extract_zero_shards_stage3(optim_files, param_shapes, dp_degree, temp_dir, dp_index):
+def extract_zero_shards_stage3(optim_files, param_shapes, dp_degree, temp_dir, weight_only, data_type, dp_index):
     state_dict = torch.load(optim_files[dp_index], map_location='cpu')
 
-    flat_state = dict(
-        exp_avg=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg"],
-        exp_avg_sq=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg_sq"],
-        fp32=state_dict[OPTIMIZER_STATE_DICT]['fp32_flat_groups'][0],
-    )
+    param_state = OrderedDict()
+
+    if weight_only:
+        flat_state = dict(fp32=state_dict[OPTIMIZER_STATE_DICT]['fp32_flat_groups'][0].detach(), )
+    else:
+        flat_state = dict(
+            exp_avg=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg"],
+            exp_avg_sq=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg_sq"],
+            fp32=state_dict[OPTIMIZER_STATE_DICT]['fp32_flat_groups'][0],
+        )
 
     offset = 0
     for name, shape in param_shapes.items():
@@ -164,10 +174,12 @@ def extract_zero_shards_stage3(optim_files, param_shapes, dp_degree, temp_dir, d
         partitioned_numel, _ = _zero_partitioned_param_info(unpartitioned_numel, dp_degree)
         padding_free_numel = min(partitioned_numel, abs(unpartitioned_numel - dp_index * partitioned_numel))
         for state_key in flat_state.keys():
-            dump_param_fragment(temp_dir, 0, dp_index, state_key, flat_state[state_key], name, offset,
-                                padding_free_numel)
+            dump_param_fragment(param_state, temp_dir, 0, dp_index, state_key, flat_state[state_key], name, offset,
+                                padding_free_numel, data_type, weight_only)
         offset += partitioned_numel
 
+    return dp_index, param_state
+
 
 cnt = 0
 
@@ -176,23 +188,29 @@ def dp_index_to_str(dp_index):
     return f"{dp_index:0>2d}"
 
 
-def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor, param_name, offset, numel):
+def dump_param_fragment(param_state, dir, tp_index, dp_index, state_name, state_flat_tensor, param_name, offset, numel,
+                        data_type, weight_only):
 
     global cnt  # temp hack
 
-    param_base_path = os.path.join(dir, param_name, str(tp_index))
-    os.makedirs(param_base_path, exist_ok=True)
-
     cnt += 1
 
-    path = os.path.join(param_base_path, f"{state_name}.{dp_index_to_str(dp_index)}")
-
-    #print(f"{param_name}: {offset}: {numel} => {path}")
-
     # State might be a python int or a tensor
     if state_name != "step" and torch.is_tensor(state_flat_tensor):
         state_flat_tensor = state_flat_tensor.narrow(0, offset, numel).clone()
-    _save_checkpoint(path, state_flat_tensor)
+
+        if data_type == "FP16":
+            state_flat_tensor = state_flat_tensor.to(torch.float16)
+        elif data_type == "BF16":
+            state_flat_tensor = state_flat_tensor.to(torch.bfloat16)
+
+    if weight_only:
+        param_state[param_name] = state_flat_tensor
+    else:
+        param_base_path = os.path.join(dir, param_name, str(tp_index))
+        os.makedirs(param_base_path, exist_ok=True)
+        path = os.path.join(param_base_path, f"{state_name}.{dp_index_to_str(dp_index)}")
+        _save_checkpoint(path, state_flat_tensor)
 
 
 def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape=None):
@@ -360,19 +378,26 @@ def _do_parallel_work(do_work, work_chunks, num_workers):
     return results
 
 
-def _extract_zero_shard_files(args, ds_checkpoint, temp_dir):
+def _extract_zero_shard_files(args, ds_checkpoint, temp_dir, weight_only=False, data_type="FP32"):
     _3d_range_list = list(
         itertools.product(range(ds_checkpoint.pp_degree), range(ds_checkpoint.tp_degree),
                           range(ds_checkpoint.dp_degree)))
     #pprint(f'{_3d_range_list=}')
 
-    do_work = partial(extract_zero_shards, temp_dir, ds_checkpoint)
-    _do_parallel_work(do_work, _3d_range_list, args.num_extract_workers)
+    do_work = partial(extract_zero_shards, temp_dir, ds_checkpoint, weight_only, data_type)
+    return _do_parallel_work(do_work, _3d_range_list, args.num_extract_workers)
 
 
-def _extract_zero_shard_files_stage3(args, optim_files, param_shapes, dp_degree, temp_dir):
-    do_work = partial(extract_zero_shards_stage3, optim_files, param_shapes, dp_degree, temp_dir)
-    _do_parallel_work(do_work, list(range(dp_degree)), args.num_extract_workers)
+def _extract_zero_shard_files_stage3(args,
+                                     optim_files,
+                                     param_shapes,
+                                     dp_degree,
+                                     temp_dir,
+                                     weight_only=False,
+                                     data_type="FP32"):
+    do_work = partial(extract_zero_shards_stage3, optim_files, param_shapes, dp_degree, temp_dir, weight_only,
+                      data_type)
+    return _do_parallel_work(do_work, list(range(dp_degree)), args.num_extract_workers)
 
 
 def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir):
diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py
index e69ecd9acb5a..f864d333ec0b 100755
--- a/deepspeed/utils/zero_to_fp32.py
+++ b/deepspeed/utils/zero_to_fp32.py
@@ -33,6 +33,10 @@
                                      FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
                                      FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
+from deepspeed.checkpoint import DeepSpeedCheckpoint
+
+from deepspeed.checkpoint.ds_to_universal import _inject_missing_state, _extract_zero_shard_files, _extract_zero_shard_files_stage3, _get_model_state_files, _parse_model_states_stage3, _get_optim_files
+
 
 @dataclass
 class zero_model_state:
@@ -99,6 +103,7 @@ def get_model_state_files(checkpoint_dir):
 
 def parse_model_states(files):
     zero_model_states = []
+    zero_stage = None
     for file in files:
         state_dict = torch.load(file, map_location=device)
 
@@ -140,7 +145,9 @@ def parse_model_states(files):
                                          frozen_param_fragments=frozen_param_fragments)
         zero_model_states.append(z_model_state)
 
-    return zero_model_states
+        if zero_stage is None:
+            zero_stage = state_dict['ds_config']['zero_optimization']['stage']
+    return zero_stage, zero_model_states
 
 
 def parse_optim_states(files, ds_checkpoint_dir):
@@ -205,20 +212,15 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_
     """
     print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
 
-    optim_files = get_optim_files(ds_checkpoint_dir)
-    zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
-    print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
-
     model_files = get_model_state_files(ds_checkpoint_dir)
-
-    zero_model_states = parse_model_states(model_files)
+    zero_stage, zero_model_states = parse_model_states(model_files)
     print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
 
     if zero_stage <= 2:
-        return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
+        return _get_fp32_state_dict_from_zero2_checkpoint(ds_checkpoint_dir, zero_model_states,
                                                           exclude_frozen_parameters)
     elif zero_stage == 3:
-        return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
+        return _get_fp32_state_dict_from_zero3_checkpoint(ds_checkpoint_dir, zero_model_states,
                                                           exclude_frozen_parameters)
 
 
@@ -332,10 +334,22 @@ def zero2_align(x):
     print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
 
 
-def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
-                                               exclude_frozen_parameters):
+def _consolidate_ucp_checkpoints(args, state_dict, slice_shapes):
+    zero_output_folder = os.path.join(args.output_dir, "zero")
+
+    for param in slice_shapes.keys():
+        ucp_checkpoint_path = os.path.join(zero_output_folder, param, "fp32.pt")
+        weight = torch.load(ucp_checkpoint_path, map_location=device)
+        state_dict[param] = weight['param']
+
+
+def _get_fp32_state_dict_from_zero2_checkpoint(ds_checkpoint_dir, zero_model_states, exclude_frozen_parameters):
+
     state_dict = OrderedDict()
 
+    ds_checkpoint = DeepSpeedCheckpoint(ds_checkpoint_dir)
+    _inject_missing_state(ds_checkpoint)
+
     # buffers
     buffers = zero_model_states[0].buffers
     state_dict.update(buffers)
@@ -345,7 +359,20 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zer
     if not exclude_frozen_parameters:
         _zero2_merge_frozen_params(state_dict, zero_model_states)
 
-    _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
+    param_shards = _extract_zero_shard_files(args,
+                                             ds_checkpoint,
+                                             temp_dir=None,
+                                             weight_only=True,
+                                             data_type=args.data_type)
+
+    param_shards.sort(key=lambda x: x[0])
+
+    for _, param in param_shards:
+        for key, value in param.items():
+            if key in state_dict:
+                state_dict[key] = torch.cat((state_dict[key], value), 0)
+            else:
+                state_dict[key] = value
 
     # recover shared parameters
     for pair in zero_model_states[0].shared_params:
@@ -450,10 +477,15 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero
     print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
 
 
-def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
-                                               exclude_frozen_parameters):
+def _get_fp32_state_dict_from_zero3_checkpoint(ds_checkpoint_dir, zero_model_states, exclude_frozen_parameters):
     state_dict = OrderedDict()
 
+    model_files = _get_model_state_files(ds_checkpoint_dir)
+    optim_files = _get_optim_files(ds_checkpoint_dir)
+    param_shapes = _parse_model_states_stage3(model_files)
+    param_shapes = {k: v for d in param_shapes for k, v in d.items()}
+    world_size = len(model_files)
+
     # buffers
     buffers = zero_model_states[0].buffers
     state_dict.update(buffers)
@@ -463,7 +495,22 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zer
     if not exclude_frozen_parameters:
         _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
 
-    _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
+    param_shards = _extract_zero_shard_files_stage3(args,
+                                                    optim_files,
+                                                    param_shapes,
+                                                    world_size,
+                                                    temp_dir=None,
+                                                    weight_only=True,
+                                                    data_type=args.data_type)
+
+    param_shards.sort(key=lambda x: x[0])
+
+    for _, param in param_shards:
+        for key, value in param.items():
+            if key in state_dict:
+                state_dict[key] = torch.cat((state_dict[key], value), 0)
+            else:
+                state_dict[key] = value
 
     # recover shared parameters
     for pair in zero_model_states[0].shared_params:
@@ -475,7 +522,7 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zer
 
 def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False):
     """
-    Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
+    Convert ZeRO 2 or 3 checkpoint into a single fp32/fp16/bf16 consolidated state_dict that can be loaded with
     ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
     via a model hub.
 
@@ -555,7 +602,6 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
             print('If you want to use `max_shard_size`, please `pip install huggingface_hub`')
             raise
 
-    # Convert zero checkpoint to state_dict
     state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters)
 
     # Shard the model if it is too big.
@@ -662,6 +708,19 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
                         help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
     parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
     parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
+    parser.add_argument('--num_extract_workers',
+                        default=4,
+                        type=int,
+                        help='How many parallel processes to extract zero shards')
+    parser.add_argument('--no_strict',
+                        dest='strict',
+                        action='store_false',
+                        help='Do not perform validity checks on converted checkpoint.')
+    parser.add_argument(
+        '--data_type',
+        default='FP32',
+        choices=['FP32', 'FP16', 'BF16'],
+        help="Specify the output tensor data type format (FP32, FP16, BF16). Default is FP32.")
     args = parser.parse_args()
     debug = args.debug
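
For reference, a minimal standalone sketch of the fragment path that the reworked dump_param_fragment implements: narrow one parameter's fragment out of a rank's flat fp32 partition, optionally down-cast it according to --data_type, and either collect it in memory (weight_only=True, the zero_to_fp32 path) or persist it to disk (the ds_to_universal path). The tensor size, offset, parameter name, and file name below are illustrative placeholders, not values from the diff.

    import torch

    flat_fp32 = torch.randn(16)   # stand-in for one rank's flat fp32 partition
    offset, numel = 4, 8          # fragment location within the flat buffer
    data_type = "BF16"            # mirrors the new --data_type choices
    weight_only = True            # zero_to_fp32 collects tensors in memory

    # Slice out this parameter's fragment and detach it from the flat buffer.
    fragment = flat_fp32.narrow(0, offset, numel).clone()

    # Optional down-cast, as done in dump_param_fragment.
    if data_type == "FP16":
        fragment = fragment.to(torch.float16)
    elif data_type == "BF16":
        fragment = fragment.to(torch.bfloat16)

    param_state = {}
    if weight_only:
        # In-memory collection; shards are later concatenated in dp_index order.
        param_state["decoder.layers.0.weight"] = fragment   # hypothetical name
    else:
        # ds_to_universal-style per-state shard file ({state_name}.{dp_index:02d}).
        torch.save(fragment, "fp32.04")

When the converter is run as a script, the same cast is selected on the command line, e.g. python zero_to_fp32.py <checkpoint_dir> <output_dir> --data_type BF16 --num_extract_workers 8 (paths are placeholders; --data_type defaults to FP32).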