Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ds-specific module id to avoid conflicts #6847

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
24 changes: 13 additions & 11 deletions deepspeed/runtime/zero/parameter_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def _start_of_forward_hook(module, *args):
self.module.register_forward_pre_hook(_start_of_forward_hook)

#likely one of them should be enough but just to be safe
self._register_hooks_recursively(self.module)
self._register_deepspeed_module(self.module)

# Add top module to stack trace
global FWD_MODULE_STACK
Expand All @@ -269,19 +269,19 @@ def mark_persistent_parameters(self, param_threshold, model_threshold):

return persistent_params

def _register_hooks_recursively(self, module, count=[0]):
def _register_deepspeed_module(self, module, count=[0]):
my_count = count[0]
module.id = my_count
module.ds_id = my_count

#print(f"{module.__class__} : {module.id}")
#print(f"{module.__class__} : {module.ds_id}")

if z3_leaf_module(module):
for param in module.parameters():
param.ds_z3_leaf_module = module
else:
for child in module.children():
count[0] = count[0] + 1
self._register_hooks_recursively(child, count=count)
self._register_deepspeed_module(child, count=count)

@instrument_w_nvtx
def _pre_forward_module_hook(module, *args):
Expand Down Expand Up @@ -466,14 +466,16 @@ def pre_sub_module_forward_function(self, sub_module):

@torch.no_grad()
def post_sub_module_forward_function(self, sub_module):
see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release",
force=False)
see_memory_usage(
f"After sub module function {sub_module.__class__.__name__} {sub_module.ds_id} before release",
force=False)

param_coordinator = self.get_param_coordinator()
param_coordinator.release_sub_module(sub_module)

see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release",
force=False)
see_memory_usage(
f"After sub module function {sub_module.__class__.__name__} {sub_module.ds_id} after release",
force=False)

@torch.no_grad()
def pre_sub_module_backward_function(self, sub_module):
Expand All @@ -488,13 +490,13 @@ def pre_sub_module_backward_function(self, sub_module):
def post_sub_module_backward_function(self, sub_module):
# assert sub_module.training, "backward pass is invalid for module in evaluation mode"
see_memory_usage(
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.ds_id} before release",
force=False)

self.get_param_coordinator().release_sub_module(sub_module)

see_memory_usage(
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.ds_id} after release",
force=False)

def _set_z3_leaf_modules_by_threshold(self, module, zero_module_granularity_threshold):
Expand Down
24 changes: 12 additions & 12 deletions deepspeed/runtime/zero/partitioned_param_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,17 +175,17 @@ def trace_prologue(self, sub_module: Module) -> None:
# sub_module must match expectation else invalidate trace cache
if len(self.__submodule_order) <= self.__step_id:
print_rank_0(
f"Invalidate trace cache @ step {self.__step_id} and module {sub_module.id}: "
f"Invalidate trace cache @ step {self.__step_id} and module {sub_module.ds_id}: "
f"cache has only {len(self.__submodule_order)} modules",
force=True)
self._invalidate_trace()
return

if sub_module != self.__submodule_order[self.__step_id]:
expected_module_id = self.__submodule_order[self.__step_id].id
expected_module_id = self.__submodule_order[self.__step_id].ds_id
print_rank_0(
f"Invalidate trace cache @ step {self.__step_id}: "
f"expected module {expected_module_id}, but got module {sub_module.id}",
f"expected module {expected_module_id}, but got module {sub_module.ds_id}",
force=True)
self._invalidate_trace()

Expand All @@ -199,7 +199,7 @@ def record_module(self, sub_module: Module) -> None:
raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}")

self.__submodule_order.append(sub_module)
self.__step_id_module_fetched_for[sub_module.id].append(self.__step_id)
self.__step_id_module_fetched_for[sub_module.ds_id].append(self.__step_id)

def record_parameters(self, sub_module: Module) -> None:
if is_compiling():
Expand All @@ -208,7 +208,7 @@ def record_parameters(self, sub_module: Module) -> None:
if not self.is_record_trace():
raise RuntimeError(f"attempted to record trace when status = {self.__trace_mode}")

step_id = self.__step_id_module_fetched_for[sub_module.id].popleft()
step_id = self.__step_id_module_fetched_for[sub_module.ds_id].popleft()
for param in sorted(set(iter_params(sub_module, recurse=z3_leaf_module(sub_module))), key=lambda p: p.ds_id):
self.__param_order.append(__class__.__ParamInTrace(param=param, step_id_last_used_at=step_id))

Expand All @@ -228,7 +228,7 @@ def reset_step(self) -> None:

if not self.is_complete_trace(): # not self.trace_complete:
# Make sure that recorded submodule orders are identical across ranks
assert_ints_same_as_other_ranks([m.id for m in self.__submodule_order])
assert_ints_same_as_other_ranks([m.ds_id for m in self.__submodule_order])

if self.is_record_trace():
# Successfully recorded a trace
Expand All @@ -241,7 +241,7 @@ def reset_step(self) -> None:
self.__param_order = tuple(self.__param_order) # freeze
self.__trace_mode = ZeRoTraceMode.COMPLETE
print_rank_0(
f"completed record trace of {len(self.__submodule_order)} sub modules: {[m.id for m in self.__submodule_order]}",
f"completed record trace of {len(self.__submodule_order)} sub modules: {[m.ds_id for m in self.__submodule_order]}",
force=False)
else:
# Enable trace recording for next forward/backward pass
Expand Down Expand Up @@ -284,7 +284,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
"""
if logger.isEnabledFor(logging.DEBUG):
debug_rank0(
f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule, recurse=z3_leaf_module(current_submodule))]} "
f"{self.__step_id}: M{current_submodule.ds_id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule, recurse=z3_leaf_module(current_submodule))]} "
+ str({
"avail": f"{self.__n_available_params:.1e}",
"queue_sz": f"{len(self.__param_queue or [])}",
Expand All @@ -297,7 +297,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:

if fetch_numel > 0:
event_name = __class__.FORWARD_FETCH_SUBMIT if forward else __class__.BACKWARD_FETCH_SUBMIT
self._dump_param_ids(event_name, current_submodule.id,
self._dump_param_ids(event_name, current_submodule.ds_id,
[p.ds_id for p in params_to_fetch if p.ds_status == ZeroParamStatus.NOT_AVAILABLE])
self.__profiler.start_event(event_name)
# kick off all gather for params in the immediately required submodule
Expand All @@ -314,7 +314,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
fast_fetch = self.fast_sharding_for_leaf_module and z3_leaf_module(current_submodule)
# wait for parameters in the immediately needed submodule to become available
for param in params_to_fetch:
param.ds_active_sub_modules.add(current_submodule.id)
param.ds_active_sub_modules.add(current_submodule.ds_id)
if logger.isEnabledFor(logging.DEBUG):
debug_rank0(f"-wait: {param.ds_summary()}")
if param in self.__inflight_param_registry:
Expand Down Expand Up @@ -358,7 +358,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
if discarded_from_prefetch_queue != params_not_already_fetched:
raise RuntimeError(
f"tracing error at step {self.__step_id}: \n"
f"module id: {current_submodule.id}, training: {current_submodule.training}\n"
f"module id: {current_submodule.ds_id}, training: {current_submodule.training}\n"
f"expected the next {len(params_not_already_fetched)} parameters in the "
f"parameter fetch queue to be {tuple(p.ds_summary(use_debug_name=True) for p in params_not_already_fetched)} \n"
f"but got \n {tuple(p.ds_summary(use_debug_name=True) for p in discarded_from_prefetch_queue)}.")
Expand Down Expand Up @@ -425,7 +425,7 @@ def release_sub_module(self, submodule: Module) -> None:
empty_buffer = torch.empty(1, device=get_accelerator().current_device())

for param in iter_params(submodule, recurse=z3_leaf_module(submodule)):
param.ds_active_sub_modules.discard(submodule.id)
param.ds_active_sub_modules.discard(submodule.ds_id)
if param.ds_id in params_to_release and not param.is_external_param:
self.__release_param(param, free_data)
if not free_data:
Expand Down
34 changes: 34 additions & 0 deletions tests/unit/runtime/zero/test_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -1673,3 +1673,37 @@ def test(self, prefetch_ratio, zero_stage=3):
with torch.no_grad():
for batch in data_loader:
loss = model(batch[0], batch[1])


# Avoid overwriting client module id
# https://github.com/microsoft/DeepSpeed/issues/6772
class TestZero3ClientModuleID(DistributedTest):
world_size = 2

def test_client_module_id(self):
config_dict = {
"train_micro_batch_size_per_gpu": 1,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
},
"zero_optimization": {
"stage": 3
},
}

class MyModel(torch.nn.Module):

def __init__(self):
super().__init__()
self.id = 3 # ID arbitrary client usage, e.g. GPU placement
self.fc = Linear(128, 128)

def forward(self, x):
return self.fc(x)

model = MyModel()
pre_init_m_id = model.id
model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
post_init_m_id = model.id
assert pre_init_m_id == post_init_m_id
Loading