diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py index 31fec30be788..9fbd91f750a9 100644 --- a/deepspeed/runtime/pipe/module.py +++ b/deepspeed/runtime/pipe/module.py @@ -662,3 +662,11 @@ def get_additional_losses(self): Return a dictionary of {"loss name": loss_value} or None if no additional losses. """ return None + + def compile(self, *args, **kwargs): + for idx, layer in enumerate(self.forward_funcs): + if isinstance(layer, nn.Module): + layer.compile(*args, **kwargs) + else: + new_layer = torch.compile(layer, *args, **kwargs) + self.forward_funcs[idx] = new_layer diff --git a/tests/unit/pipe/test_pipe_module.py b/tests/unit/pipe/test_pipe_module.py index 05c6a82ef55a..2a8a4b9b7d82 100644 --- a/tests/unit/pipe/test_pipe_module.py +++ b/tests/unit/pipe/test_pipe_module.py @@ -60,9 +60,12 @@ def batch_input(): class TestPipeModuleSequential(DistributedTest): world_size = 2 + # needs to be set for torch.compile: running torch.compile with daemonic process causes an error + non_daemonic_procs = True @pytest.mark.parametrize("activation_checkpoints", [False, True]) - def test(self, sequential_model, simple_config, batch_input, activation_checkpoints): + @pytest.mark.parametrize("use_compile", [False, True]) + def test(self, sequential_model, simple_config, batch_input, activation_checkpoints, use_compile): base_model = copy.deepcopy(sequential_model) base_input = batch_input.clone().detach() base_output = base_model(base_input) @@ -71,7 +74,8 @@ def test(self, sequential_model, simple_config, batch_input, activation_checkpoi pipe_model = copy.deepcopy(sequential_model) pipe_model = PipelineModule(layers=pipe_model, num_stages=2) - + if (use_compile): + pipe_model.compile() # Ensure all parameters are accounted for. my_params = sum(p.numel() for p in pipe_model.parameters()) total_pipe_params = torch.LongTensor([my_params]).to(get_accelerator().device_name())