Skip to content

Commit

Permalink
fix: remove mark-time checking for non-existence of the flag as DeepS…
Browse files Browse the repository at this point in the history
…peedEngine propagates flag from the internal model
  • Loading branch information
Essoz committed Dec 16, 2024
1 parent 238ba1f commit dc81325
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 12 deletions.
10 changes: 5 additions & 5 deletions deepspeed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,18 +68,18 @@ def _parse_version(version_str):

def _mark_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
"""Mark a trainobj as initialized by setting the ds_is_inited attribute to True."""
# we shouldn't hit the assert below, but just in case
assert not hasattr(
trainobj, 'ds_is_inited'
), "Model has already been initialized, please make sure to only call deepspeed.initialize on a model once."
if hasattr(trainobj, 'ds_is_inited'):
assert trainobj.ds_is_inited, "Not expecting the training object has `ds_is_inited` to be False if it exists, make sure you didn't set it to False or called deepspeed.initialize on the model more than once."
return

trainobj.ds_is_inited = True


def _is_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]):
"""Check if a trainobj has been initialized by checking the ds_is_inited attribute."""
if hasattr(trainobj, 'ds_is_inited'):
# we shouldn't hit the assert below, but just in case
assert trainobj.ds_is_inited, "Not expecting the model has `ds_is_inited` to be False if it exists, make sure you didn't set it to False or called deepspeed.initialize on the model more than once."
assert trainobj.ds_is_inited, "Not expecting the training object has `ds_is_inited` to be False if it exists, make sure you didn't set it to False or called deepspeed.initialize on the model more than once."
return True
return False

Expand Down
33 changes: 26 additions & 7 deletions tests/unit/runtime/test_ds_initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,17 +445,14 @@ def test_no_repeated_init(self):
hidden_dim = 10
model = SimpleModel(hidden_dim)
client_optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model = SimpleModel()
# Initialize DeepSpeed configurations for fp16
config_dict = {'train_batch_size': 1}

client_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
# Initialize DeepSpeed engine
_assert_trainobjs_not_inited(model=model, optimizer=client_optimizer, lr_scheduler=None)
model_engine, optim, dataloader, scheduler = deepspeed.initialize(model=model,
optimizer=client_optimizer,
config_params=config_dict)
model_engine, optim, _, _ = deepspeed.initialize(model=model,
optimizer=client_optimizer,
config_params=config_dict)

# arguments should be marked as initialized now
assert _is_initialized(model), "Client model should be marked as initialized"
Expand All @@ -464,7 +461,6 @@ def test_no_repeated_init(self):
# return values should also be marked as initialized
assert _is_initialized(model_engine), "Model engine should be marked as initialized"
assert _is_initialized(optim), "Optimizer should be marked as initialized"
assert _is_initialized(scheduler), "Scheduler should be marked as initialized"

exception_raised = False
try:
Expand All @@ -473,3 +469,26 @@ def test_no_repeated_init(self):
exception_raised = True

assert exception_raised, "Repeated initialization should raise an exception"

exception_raised = False
try:
deepspeed.initialize(model=model_engine, optimizer=client_optimizer, config_params=config_dict)
except ValueError:
exception_raised = True

assert exception_raised, "Initialization on ds types should raise an exception"

exception_raised = False
try:
deepspeed.initialize(model=model, optimizer=client_optimizer, config_params=config_dict)
except ValueError:
exception_raised = True

assert exception_raised, "Initialization on ds types should raise an exception"

exception_raised = False
try:
deepspeed.initialize(model=model_engine, optimizer=client_optimizer, config_params=config_dict)
except ValueError:
exception_raised = True
assert exception_raised, "Initialization on ds types should raise an exception"

0 comments on commit dc81325

Please sign in to comment.