From 3573858e7ce2c723b8c43231c6c6b0cf97dca2fc Mon Sep 17 00:00:00 2001 From: Nir Sonnenschein Date: Mon, 30 Dec 2024 20:53:41 +0200 Subject: [PATCH] Change compile for pipeline module torch.compile (#6478) We have encountered and issue with torch.compile and the pipeline module. modifying a member of the module (micro_offset) during the forward function will cause torch compile to restart the analysis and treat the module as dynamic. In order to bypass this issue without significantly changing the way the pipeline module works we propose to compile only the layers in the pipeline module instead of the forward function of pipeline module. this will bypass the issue and should still give most of the benefit of torch compiling the pipeline module while avoiding the issue. --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- deepspeed/runtime/pipe/module.py | 8 ++++++++ tests/unit/pipe/test_pipe_module.py | 8 ++++++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py index 31fec30be788..9fbd91f750a9 100644 --- a/deepspeed/runtime/pipe/module.py +++ b/deepspeed/runtime/pipe/module.py @@ -662,3 +662,11 @@ def get_additional_losses(self): Return a dictionary of {"loss name": loss_value} or None if no additional losses. """ return None + + def compile(self, *args, **kwargs): + for idx, layer in enumerate(self.forward_funcs): + if isinstance(layer, nn.Module): + layer.compile(*args, **kwargs) + else: + new_layer = torch.compile(layer, *args, **kwargs) + self.forward_funcs[idx] = new_layer diff --git a/tests/unit/pipe/test_pipe_module.py b/tests/unit/pipe/test_pipe_module.py index 05c6a82ef55a..2a8a4b9b7d82 100644 --- a/tests/unit/pipe/test_pipe_module.py +++ b/tests/unit/pipe/test_pipe_module.py @@ -60,9 +60,12 @@ def batch_input(): class TestPipeModuleSequential(DistributedTest): world_size = 2 + # needs to be set for torch.compile: running torch.compile with daemonic process causes an error + non_daemonic_procs = True @pytest.mark.parametrize("activation_checkpoints", [False, True]) - def test(self, sequential_model, simple_config, batch_input, activation_checkpoints): + @pytest.mark.parametrize("use_compile", [False, True]) + def test(self, sequential_model, simple_config, batch_input, activation_checkpoints, use_compile): base_model = copy.deepcopy(sequential_model) base_input = batch_input.clone().detach() base_output = base_model(base_input) @@ -71,7 +74,8 @@ def test(self, sequential_model, simple_config, batch_input, activation_checkpoi pipe_model = copy.deepcopy(sequential_model) pipe_model = PipelineModule(layers=pipe_model, num_stages=2) - + if (use_compile): + pipe_model.compile() # Ensure all parameters are accounted for. my_params = sum(p.numel() for p in pipe_model.parameters()) total_pipe_params = torch.LongTensor([my_params]).to(get_accelerator().device_name())