Skip to content

Commit

Permalink
Add guards for user behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerzr committed Sep 11, 2024
1 parent 844355b commit 87dba32
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 14 deletions.
7 changes: 4 additions & 3 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@
DeepSpeedSchedulerWrapper,
DummyOptim,
DummyScheduler,
get_active_deepspeed_plugin,
)

if is_megatron_lm_available():
Expand Down Expand Up @@ -251,7 +250,7 @@ def __init__(
gradient_accumulation_steps: int = 1,
cpu: bool = False,
dataloader_config: DataLoaderConfiguration | None = None,
deepspeed_plugin: DeepSpeedPlugin | None = None,
deepspeed_plugin: DeepSpeedPlugin | list[DeepSpeedPlugin] | None = None,
fsdp_plugin: FullyShardedDataParallelPlugin | None = None,
megatron_lm_plugin: MegatronLMPlugin | None = None,
rng_types: list[str | RNGType] | None = None,
Expand Down Expand Up @@ -564,6 +563,8 @@ def deepspeed_plugin(self):
If using multiple plugins, the first one will be the active one by default. Manually call `plugin.enable()` to
activate a different plugin.
If deepspeed is not enabled, this will return `None`.
"""
return self.state.deepspeed_plugin

Expand Down Expand Up @@ -1673,7 +1674,7 @@ def _prepare_deepspeed(self, *args):

ds_initialize = msamp_deepspeed.initialize

deepspeed_plugin = get_active_deepspeed_plugin(self.state)
deepspeed_plugin = self.deepspeed_plugin

is_dataloader_present = any(isinstance(obj, torch.utils.data.DataLoader) for obj in args)
result = [
Expand Down
12 changes: 4 additions & 8 deletions src/accelerate/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -950,10 +950,7 @@ def initialized(self) -> bool:
def __repr__(self):
repr = PartialState().__repr__() + f"\nMixed precision type: {self.mixed_precision}\n"
if self.distributed_type == DistributedType.DEEPSPEED:
from accelerate.utils.deepspeed import get_active_deepspeed_plugin

active_plugin = get_active_deepspeed_plugin(self)
repr += f"ds_config: {active_plugin.deepspeed_config}\n"
repr += f"ds_config: {self.deepspeed_plugin.deepspeed_config}\n"
return repr

def _check_initialized(self, mixed_precision=None, cpu=None):
Expand Down Expand Up @@ -982,10 +979,7 @@ def use_fp16(self):
@property
def mixed_precision(self):
if self.distributed_type == DistributedType.DEEPSPEED:
from accelerate.utils.deepspeed import get_active_deepspeed_plugin

active_plugin = get_active_deepspeed_plugin(self)
config = active_plugin.deepspeed_config
config = self.deepspeed_plugin.deepspeed_config
if config.get("fp16", {}).get("enabled", False):
mixed_precision = "fp16"
elif config.get("bf16", {}).get("enabled", False):
Expand Down Expand Up @@ -1109,6 +1103,8 @@ def deepspeed_plugin(self):
If using multiple plugins, the first one will be the active one by default. Manually call `plugin.enable()` to
activate a different plugin.
If deepspeed is not enabled, this will return `None`.
"""
if self.distributed_type != DistributedType.DEEPSPEED:
return None
Expand Down
20 changes: 17 additions & 3 deletions src/accelerate/utils/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1100,7 +1100,7 @@ def __post_init__(self):
self.zero3_init_flag = False
# NOTE: Set to False by default, will be set to `True` automatically if it's the first plugin passed
# to the `Accelerator`'s `deepspeed_plugin` param, *or* `plugin.enable()` is manually called
self.enabled = False
self._set_enabled(False)

# Ignore if it's already set
if self.enable_msamp and "msamp" not in self.deepspeed_config:
Expand Down Expand Up @@ -1298,10 +1298,24 @@ def enable(self):
for plugin in AcceleratorState().deepspeed_plugins:
if plugin is not self:
plugin.disable()
self.enabled = True
self._set_enabled(True)

def disable(self):
self.enabled = False
self._set_enabled(False)

def _set_enabled(self, value: bool):
"""
Private setter for the 'enabled' attribute.
"""
self._enabled = value

@property
def enabled(self):
return self._enabled

@enabled.setter
def enabled(self, value):
raise NotImplementedError("'enabled' can only be set through the 'enable()' method.")


@dataclass
Expand Down
11 changes: 11 additions & 0 deletions tests/deepspeed/test_deepspeed_multiple_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,17 @@ def test_enable_disable(self):
assert ds_zero2.enabled
assert get_active_deepspeed_plugin(accelerator.state) == ds_zero2

def test_enable_disable_manually_set(self):
ds_zero2, _ = self.get_ds_plugins()
ds_zero2.enable()
with self.assertRaises(NotImplementedError):
ds_zero2.enabled = False
assert ds_zero2.enabled
ds_zero2.disable()
assert not ds_zero2.enabled
with self.assertRaises(NotImplementedError):
ds_zero2.enabled = True

def test_prepare_multiple_models(self):
ds_zero2, ds_zero3 = self.get_ds_plugins()
accelerator = Accelerator(deepspeed_plugin=[ds_zero2, ds_zero3])
Expand Down

0 comments on commit 87dba32

Please sign in to comment.