From 7ee67344ab57eec30d100cbbfea2dc09ac4a22e0 Mon Sep 17 00:00:00 2001 From: Xinyu Lian Date: Tue, 22 Oct 2024 22:45:26 -0500 Subject: [PATCH 1/4] support the conversion from ZeRO checkpoints to FP32/FP16/BF16 parameter weights in parallel via UCP --- deepspeed/checkpoint/ds_to_universal.py | 101 +++++++++++++--------- deepspeed/utils/zero_to_fp32.py | 109 ++++++++++++++++++------ 2 files changed, 147 insertions(+), 63 deletions(-) diff --git a/deepspeed/checkpoint/ds_to_universal.py b/deepspeed/checkpoint/ds_to_universal.py index e5974a30df22..c5f84753022b 100755 --- a/deepspeed/checkpoint/ds_to_universal.py +++ b/deepspeed/checkpoint/ds_to_universal.py @@ -5,16 +5,18 @@ # DeepSpeed Team -from functools import partial -from itertools import chain import argparse import glob import itertools import math -from concurrent.futures import ProcessPoolExecutor import os import re import shutil +from collections import OrderedDict +from concurrent.futures import ProcessPoolExecutor +from functools import partial +from itertools import chain + import torch import tqdm #from pprint import pprint @@ -109,7 +111,7 @@ def _save_checkpoint(file_path, chkpt_sd): torch.save(chkpt_sd, file_path) -def extract_zero_shards(dir, ds_checkpoint, indices_3D): +def extract_zero_shards(dir, ds_checkpoint, weight_only, data_type, indices_3D): pp_index, tp_index, dp_index = indices_3D sd = ds_checkpoint.get_zero_checkpoint_state(pp_index=pp_index, tp_index=tp_index, dp_index=dp_index) @@ -121,19 +123,20 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D): pipeline_replicated_params = universal_checkpoint_info.get(PIPELINE_REPLICATED_PARAMETER_PATTERNS, []) # print(f'{pipeline_replicated_params=}') - # dict state_groups = optim_sd[BASE_OPTIMIZER_STATE]["state"] - # list fp32_groups = optim_sd[SINGLE_PARTITION_OF_FP32_GROUPS] - param_groups_cnt = len(state_groups) - for param_group_id in range(param_groups_cnt): + param_state = OrderedDict() - flat_state = dict( - exp_avg=state_groups[param_group_id]["exp_avg"], - exp_avg_sq=state_groups[param_group_id]["exp_avg_sq"], - fp32=fp32_groups[param_group_id], - ) + for param_group_id in range(len(state_groups)): + if weight_only: + flat_state = dict(fp32=fp32_groups[param_group_id].detach(), ) + else: + flat_state = dict( + exp_avg=state_groups[param_group_id]["exp_avg"], + exp_avg_sq=state_groups[param_group_id]["exp_avg_sq"], + fp32=fp32_groups[param_group_id], + ) if "step" in state_groups[param_group_id]: flat_state["step"] = state_groups[param_group_id]["step"] @@ -145,18 +148,25 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D): # pprint(f"dpt{dp_index}{pp_index}{tp_index} {param_group_id} {name} => {fragment_mapping.start}:{fragment_mapping.numel}") for state_key in flat_state.keys(): - dump_param_fragment(dir, tp_index, dp_index, state_key, flat_state[state_key], name, - fragment_mapping.start, fragment_mapping.numel) + dump_param_fragment(param_state, dir, tp_index, dp_index, state_key, flat_state[state_key], name, + fragment_mapping.start, fragment_mapping.numel, data_type, weight_only) + + return dp_index, param_state -def extract_zero_shards_stage3(optim_files, param_shapes, dp_degree, temp_dir, dp_index): +def extract_zero_shards_stage3(optim_files, param_shapes, dp_degree, temp_dir, weight_only, data_type, dp_index): state_dict = torch.load(optim_files[dp_index], map_location='cpu') - flat_state = dict( - exp_avg=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg"], - 
exp_avg_sq=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg_sq"], - fp32=state_dict[OPTIMIZER_STATE_DICT]['fp32_flat_groups'][0], - ) + param_state = OrderedDict() + + if weight_only: + flat_state = dict(fp32=state_dict[OPTIMIZER_STATE_DICT]['fp32_flat_groups'][0].detach(), ) + else: + flat_state = dict( + exp_avg=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg"], + exp_avg_sq=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg_sq"], + fp32=state_dict[OPTIMIZER_STATE_DICT]['fp32_flat_groups'][0], + ) offset = 0 for name, shape in param_shapes.items(): @@ -164,10 +174,12 @@ def extract_zero_shards_stage3(optim_files, param_shapes, dp_degree, temp_dir, d partitioned_numel, _ = _zero_partitioned_param_info(unpartitioned_numel, dp_degree) padding_free_numel = min(partitioned_numel, abs(unpartitioned_numel - dp_index * partitioned_numel)) for state_key in flat_state.keys(): - dump_param_fragment(temp_dir, 0, dp_index, state_key, flat_state[state_key], name, offset, - padding_free_numel) + dump_param_fragment(param_state, temp_dir, 0, dp_index, state_key, flat_state[state_key], name, offset, + padding_free_numel, data_type, weight_only) offset += partitioned_numel + return dp_index, param_state + cnt = 0 @@ -176,23 +188,29 @@ def dp_index_to_str(dp_index): return f"{dp_index:0>2d}" -def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor, param_name, offset, numel): +def dump_param_fragment(param_state, dir, tp_index, dp_index, state_name, state_flat_tensor, param_name, offset, numel, + data_type, weight_only): global cnt # temp hack - param_base_path = os.path.join(dir, param_name, str(tp_index)) - os.makedirs(param_base_path, exist_ok=True) - cnt += 1 - path = os.path.join(param_base_path, f"{state_name}.{dp_index_to_str(dp_index)}") - - #print(f"{param_name}: {offset}: {numel} => {path}") - # State might be a python int or a tensor if state_name != "step" and torch.is_tensor(state_flat_tensor): state_flat_tensor = state_flat_tensor.narrow(0, offset, numel).clone() - _save_checkpoint(path, state_flat_tensor) + + if data_type == "FP16": + state_flat_tensor = state_flat_tensor.to(torch.float16) + elif data_type == "BF16": + state_flat_tensor = state_flat_tensor.to(torch.bfloat16) + + if weight_only: + param_state[param_name] = state_flat_tensor + else: + param_base_path = os.path.join(dir, param_name, str(tp_index)) + os.makedirs(param_base_path, exist_ok=True) + path = os.path.join(param_base_path, f"{state_name}.{dp_index_to_str(dp_index)}") + _save_checkpoint(path, state_flat_tensor) def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape=None): @@ -360,19 +378,26 @@ def _do_parallel_work(do_work, work_chunks, num_workers): return results -def _extract_zero_shard_files(args, ds_checkpoint, temp_dir): +def _extract_zero_shard_files(args, ds_checkpoint, temp_dir, weight_only=False, data_type="FP32"): _3d_range_list = list( itertools.product(range(ds_checkpoint.pp_degree), range(ds_checkpoint.tp_degree), range(ds_checkpoint.dp_degree))) #pprint(f'{_3d_range_list=}') - do_work = partial(extract_zero_shards, temp_dir, ds_checkpoint) - _do_parallel_work(do_work, _3d_range_list, args.num_extract_workers) + do_work = partial(extract_zero_shards, temp_dir, ds_checkpoint, weight_only, data_type) + return _do_parallel_work(do_work, _3d_range_list, args.num_extract_workers) -def _extract_zero_shard_files_stage3(args, optim_files, param_shapes, dp_degree, temp_dir): - do_work = 
partial(extract_zero_shards_stage3, optim_files, param_shapes, dp_degree, temp_dir) - _do_parallel_work(do_work, list(range(dp_degree)), args.num_extract_workers) +def _extract_zero_shard_files_stage3(args, + optim_files, + param_shapes, + dp_degree, + temp_dir, + weight_only=False, + data_type="FP32"): + do_work = partial(extract_zero_shards_stage3, optim_files, param_shapes, dp_degree, temp_dir, weight_only, + data_type) + return _do_parallel_work(do_work, list(range(dp_degree)), args.num_extract_workers) def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir): diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py index e69ecd9acb5a..44af82840553 100755 --- a/deepspeed/utils/zero_to_fp32.py +++ b/deepspeed/utils/zero_to_fp32.py @@ -33,6 +33,10 @@ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES, FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS) +from deepspeed.checkpoint import DeepSpeedCheckpoint + +from deepspeed.checkpoint.ds_to_universal import _inject_missing_state, _extract_zero_shard_files, _extract_zero_shard_files_stage3, _get_model_state_files, _parse_model_states_stage3, _get_optim_files + @dataclass class zero_model_state: @@ -99,6 +103,7 @@ def get_model_state_files(checkpoint_dir): def parse_model_states(files): zero_model_states = [] + zero_stage = None for file in files: state_dict = torch.load(file, map_location=device) @@ -140,14 +145,16 @@ def parse_model_states(files): frozen_param_fragments=frozen_param_fragments) zero_model_states.append(z_model_state) - return zero_model_states + if zero_stage is None: + zero_stage = state_dict['ds_config']['zero_optimization']['stage'] + return zero_stage, zero_model_states def parse_optim_states(files, ds_checkpoint_dir): total_files = len(files) state_dicts = [] for f in files: - state_dict = torch.load(f, map_location=device) + state_dict = torch.load(f, map_location='cpu') # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights # and also handle the case where it was already removed by another helper script state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None) @@ -195,7 +202,7 @@ def parse_optim_states(files, ds_checkpoint_dir): return zero_stage, world_size, fp32_flat_groups -def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters): +def _get_param_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters): """ Returns fp32 state_dict reconstructed from ds checkpoint @@ -205,21 +212,16 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_ """ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") - optim_files = get_optim_files(ds_checkpoint_dir) - zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) - print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") - model_files = get_model_state_files(ds_checkpoint_dir) - - zero_model_states = parse_model_states(model_files) + zero_stage, zero_model_states = parse_model_states(model_files) print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') if zero_stage <= 2: - return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, - exclude_frozen_parameters) + return _get_param_state_dict_from_zero2_checkpoint(ds_checkpoint_dir, zero_model_states, + exclude_frozen_parameters) elif zero_stage == 3: - return 
_get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, - exclude_frozen_parameters) + return _get_param_state_dict_from_zero3_checkpoint(ds_checkpoint_dir, zero_model_states, + exclude_frozen_parameters) def _zero2_merge_frozen_params(state_dict, zero_model_states): @@ -332,10 +334,22 @@ def zero2_align(x): print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements") -def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, - exclude_frozen_parameters): +def _consolidate_ucp_checkpoints(args, state_dict, slice_shapes): + zero_output_folder = os.path.join(args.output_dir, "zero") + + for param in slice_shapes.keys(): + ucp_checkpoint_path = os.path.join(zero_output_folder, param, "fp32.pt") + weight = torch.load(ucp_checkpoint_path, map_location=device) + state_dict[param] = weight['param'] + + +def _get_param_state_dict_from_zero2_checkpoint(ds_checkpoint_dir, zero_model_states, exclude_frozen_parameters): + state_dict = OrderedDict() + ds_checkpoint = DeepSpeedCheckpoint(ds_checkpoint_dir) + _inject_missing_state(ds_checkpoint) + # buffers buffers = zero_model_states[0].buffers state_dict.update(buffers) @@ -345,7 +359,20 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zer if not exclude_frozen_parameters: _zero2_merge_frozen_params(state_dict, zero_model_states) - _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + param_shards = _extract_zero_shard_files(args, + ds_checkpoint, + temp_dir=None, + weight_only=True, + data_type=args.data_type) + + param_shards.sort(key=lambda x: x[0]) + + for _, param in param_shards: + for key, value in param.items(): + if key in state_dict: + state_dict[key] = torch.cat((state_dict[key], value), 0) + else: + state_dict[key] = value # recover shared parameters for pair in zero_model_states[0].shared_params: @@ -450,10 +477,15 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements") -def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, - exclude_frozen_parameters): +def _get_param_state_dict_from_zero3_checkpoint(ds_checkpoint_dir, zero_model_states, exclude_frozen_parameters): state_dict = OrderedDict() + model_files = _get_model_state_files(ds_checkpoint_dir) + optim_files = _get_optim_files(ds_checkpoint_dir) + param_shapes = _parse_model_states_stage3(model_files) + param_shapes = {k: v for d in param_shapes for k, v in d.items()} + world_size = len(model_files) + # buffers buffers = zero_model_states[0].buffers state_dict.update(buffers) @@ -463,7 +495,22 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zer if not exclude_frozen_parameters: _zero3_merge_frozen_params(state_dict, world_size, zero_model_states) - _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + param_shards = _extract_zero_shard_files_stage3(args, + optim_files, + param_shapes, + world_size, + temp_dir=None, + weight_only=True, + data_type=args.data_type) + + param_shards.sort(key=lambda x: x[0]) + + for _, param in param_shards: + for key, value in param.items(): + if key in state_dict: + state_dict[key] = torch.cat((state_dict[key], value), 0) + else: + state_dict[key] = value # recover shared parameters for pair in zero_model_states[0].shared_params: @@ -473,9 +520,9 
@@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zer return state_dict -def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False): +def get_param_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False): """ - Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with + Convert ZeRO 2 or 3 checkpoint into a single fp32/fp16/bf16 consolidated state_dict that can be loaded with ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example via a model hub. @@ -520,7 +567,7 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_f if not os.path.isdir(ds_checkpoint_dir): raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") - return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters) + return _get_param_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters) def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, @@ -555,8 +602,7 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, print('If you want to use `max_shard_size`, please `pip install huggingface_hub`') raise - # Convert zero checkpoint to state_dict - state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters) + state_dict = get_param_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters) # Shard the model if it is too big. weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin" @@ -624,7 +670,7 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): """ logger.info(f"Extracting fp32 weights") - state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + state_dict = get_param_state_dict_from_zero_checkpoint(checkpoint_dir, tag) logger.info(f"Overwriting model with fp32 weights") model = model.cpu() @@ -662,6 +708,19 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1") parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters") parser.add_argument("-d", "--debug", action='store_true', help="enable debug") + parser.add_argument('--num_extract_workers', + default=4, + type=int, + help='How many parallel processes to extract zero shards') + parser.add_argument('--no_strict', + dest='strict', + action='store_false', + help='Do not perform validity checks on converted checkpoint.') + parser.add_argument( + '--data_type', + default='FP32', + choices=['FP32', 'FP16', 'BF16'], + help="Specify the output tensor data type format (FP32, FP16, BF16, FP8, BF8). 
Default is FP32.") args = parser.parse_args() debug = args.debug From 869c455f1271f165f2712589b3effbc4ba7e0a33 Mon Sep 17 00:00:00 2001 From: Xinyu Lian Date: Wed, 23 Oct 2024 10:14:36 -0500 Subject: [PATCH 2/4] not change the function name for compatibility --- deepspeed/utils/zero_to_fp32.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py index 44af82840553..e1e761099eae 100755 --- a/deepspeed/utils/zero_to_fp32.py +++ b/deepspeed/utils/zero_to_fp32.py @@ -202,7 +202,7 @@ def parse_optim_states(files, ds_checkpoint_dir): return zero_stage, world_size, fp32_flat_groups -def _get_param_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters): +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters): """ Returns fp32 state_dict reconstructed from ds checkpoint @@ -217,10 +217,10 @@ def _get_param_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') if zero_stage <= 2: - return _get_param_state_dict_from_zero2_checkpoint(ds_checkpoint_dir, zero_model_states, + return _get_fp32_state_dict_from_zero2_checkpoint(ds_checkpoint_dir, zero_model_states, exclude_frozen_parameters) elif zero_stage == 3: - return _get_param_state_dict_from_zero3_checkpoint(ds_checkpoint_dir, zero_model_states, + return _get_fp32_state_dict_from_zero3_checkpoint(ds_checkpoint_dir, zero_model_states, exclude_frozen_parameters) @@ -343,7 +343,7 @@ def _consolidate_ucp_checkpoints(args, state_dict, slice_shapes): state_dict[param] = weight['param'] -def _get_param_state_dict_from_zero2_checkpoint(ds_checkpoint_dir, zero_model_states, exclude_frozen_parameters): +def _get_fp32_state_dict_from_zero2_checkpoint(ds_checkpoint_dir, zero_model_states, exclude_frozen_parameters): state_dict = OrderedDict() @@ -477,7 +477,7 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements") -def _get_param_state_dict_from_zero3_checkpoint(ds_checkpoint_dir, zero_model_states, exclude_frozen_parameters): +def _get_fp32_state_dict_from_zero3_checkpoint(ds_checkpoint_dir, zero_model_states, exclude_frozen_parameters): state_dict = OrderedDict() model_files = _get_model_state_files(ds_checkpoint_dir) @@ -520,7 +520,7 @@ def _get_param_state_dict_from_zero3_checkpoint(ds_checkpoint_dir, zero_model_st return state_dict -def get_param_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False): +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False): """ Convert ZeRO 2 or 3 checkpoint into a single fp32/fp16/bf16 consolidated state_dict that can be loaded with ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example @@ -567,7 +567,7 @@ def get_param_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_ if not os.path.isdir(ds_checkpoint_dir): raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") - return _get_param_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters) + return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters) def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, @@ -602,7 +602,7 @@ def 
convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, print('If you want to use `max_shard_size`, please `pip install huggingface_hub`') raise - state_dict = get_param_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters) + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters) # Shard the model if it is too big. weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin" @@ -670,7 +670,7 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): """ logger.info(f"Extracting fp32 weights") - state_dict = get_param_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) logger.info(f"Overwriting model with fp32 weights") model = model.cpu() From 7c55f33dc8d47d12763c8d5752eb364327712289 Mon Sep 17 00:00:00 2001 From: Xinyu Lian Date: Wed, 23 Oct 2024 10:22:32 -0500 Subject: [PATCH 3/4] nit --- deepspeed/utils/zero_to_fp32.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py index e1e761099eae..b421fd45cfe4 100755 --- a/deepspeed/utils/zero_to_fp32.py +++ b/deepspeed/utils/zero_to_fp32.py @@ -154,7 +154,7 @@ def parse_optim_states(files, ds_checkpoint_dir): total_files = len(files) state_dicts = [] for f in files: - state_dict = torch.load(f, map_location='cpu') + state_dict = torch.load(f, map_location=device) # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights # and also handle the case where it was already removed by another helper script state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None) From ff1083c8fc964e48a01d4a3880b8af67dfffdd62 Mon Sep 17 00:00:00 2001 From: Xinyu Lian Date: Fri, 25 Oct 2024 05:15:51 +0000 Subject: [PATCH 4/4] fix formatting issue --- deepspeed/utils/zero_to_fp32.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py index b421fd45cfe4..f864d333ec0b 100755 --- a/deepspeed/utils/zero_to_fp32.py +++ b/deepspeed/utils/zero_to_fp32.py @@ -218,10 +218,10 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_ if zero_stage <= 2: return _get_fp32_state_dict_from_zero2_checkpoint(ds_checkpoint_dir, zero_model_states, - exclude_frozen_parameters) + exclude_frozen_parameters) elif zero_stage == 3: return _get_fp32_state_dict_from_zero3_checkpoint(ds_checkpoint_dir, zero_model_states, - exclude_frozen_parameters) + exclude_frozen_parameters) def _zero2_merge_frozen_params(state_dict, zero_model_states):
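
Note on the conversion flow added by the patches above: the new weight-only path extracts each data-parallel rank's parameter fragments in parallel (via `_extract_zero_shard_files` / `_extract_zero_shard_files_stage3`), optionally down-casts them to FP16/BF16 inside `dump_param_fragment`, and then rebuilds each parameter by concatenating its fragments in dp-rank order. The sketch below illustrates only that merge step with a hypothetical helper name (`merge_weight_only_shards` is not part of the patch; the real logic lives in `_get_fp32_state_dict_from_zero2_checkpoint` and its ZeRO-3 counterpart), and it applies the dtype cast at merge time rather than at dump time for brevity:

    from collections import OrderedDict
    import torch

    def merge_weight_only_shards(param_shards, target_dtype=None):
        # param_shards: list of (dp_index, dict mapping param name -> flat tensor
        # fragment), mirroring what the weight_only extraction returns per rank.
        # target_dtype: optional torch dtype (torch.float16 / torch.bfloat16).
        param_shards.sort(key=lambda pair: pair[0])  # restore dp-rank order
        state_dict = OrderedDict()
        for _, fragments in param_shards:
            for name, frag in fragments.items():
                if target_dtype is not None:
                    frag = frag.to(target_dtype)
                if name in state_dict:
                    state_dict[name] = torch.cat((state_dict[name], frag), dim=0)
                else:
                    state_dict[name] = frag
        # Returns flat per-parameter tensors; reshaping back to the original
        # parameter shapes is outside the scope of this sketch.
        return state_dict

End to end, the same path is driven from the command line through the flags introduced in PATCH 1/4, e.g. `python zero_to_fp32.py <checkpoint_dir> <output> --data_type BF16 --num_extract_workers 8` (positional arguments as in the existing script; only `--data_type`, `--num_extract_workers`, and `--no_strict` are added here).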