diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py
index 0780bd75..1b26a28f 100644
--- a/megatron/core/tensor_parallel/layers.py
+++ b/megatron/core/tensor_parallel/layers.py
@@ -837,13 +837,7 @@ def __init__(
         self.gradient_accumulation_fusion = config.gradient_accumulation_fusion
         self.sequence_parallel = config.sequence_parallel
         if self.sequence_parallel and not self.input_is_parallel:
-            # raise RuntimeError("To enable `sequence_parallel`, `input_is_parallel` must be `True`")
-            print(
-                'WARNING: To enable `sequence_parallel`',
-                '`input_is_parallel` must be `True ',
-                flush=True,
-            )
-            self.input_is_parallel = True
+            raise RuntimeError("To enable `sequence_parallel`, `input_is_parallel` must be `True`")

         # Parameters.
         # Note: torch.nn.functional.linear performs XA^T + b and as a result
diff --git a/megatron/core/transformer/mlp.py b/megatron/core/transformer/mlp.py
index 2eaee70e..c71859f0 100644
--- a/megatron/core/transformer/mlp.py
+++ b/megatron/core/transformer/mlp.py
@@ -53,6 +53,7 @@ def __init__(
             ffn_hidden_size,
             config=self.config,
             init_method=self.config.init_method,
+            gather_output=False,
             bias=self.config.add_bias_linear,
             skip_bias_add=True,
             is_expert=is_expert,
@@ -75,6 +76,7 @@ def glu(x):
             config=self.config,
             init_method=self.config.output_layer_init_method,
             bias=self.config.add_bias_linear,
+            input_is_parallel=True,
             skip_bias_add=True,
             is_expert=is_expert,
         )
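For context, a minimal sketch of how the two flags added above are meant to pair up in the tensor-parallel MLP: the column-parallel first GEMM keeps its output sharded across TP ranks (gather_output=False), and the row-parallel second GEMM consumes that sharded activation directly (input_is_parallel=True), which is exactly the precondition the restored RuntimeError enforces when sequence_parallel is on. The class names and constructor keywords follow the Megatron-Core tensor_parallel API, but the helper function, sizes, and wiring below are an illustrative assumption, not code taken from this PR.

# Sketch only: pairing ColumnParallelLinear and RowParallelLinear as the MLP does.
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear

def build_mlp_linears(config, hidden_size, ffn_hidden_size):
    # First GEMM: column-parallel; keep the activation partitioned so the
    # second GEMM can consume it without an intermediate all-gather.
    fc1 = ColumnParallelLinear(
        hidden_size,
        ffn_hidden_size,
        config=config,
        init_method=config.init_method,
        gather_output=False,        # output stays sharded across TP ranks
        bias=config.add_bias_linear,
        skip_bias_add=True,
    )
    # Second GEMM: row-parallel; input_is_parallel=True tells it the input is
    # already sharded, which is required when config.sequence_parallel is set
    # (this is the condition RowParallelLinear.__init__ now raises on).
    fc2 = RowParallelLinear(
        ffn_hidden_size,
        hidden_size,
        config=config,
        init_method=config.output_layer_init_method,
        bias=config.add_bias_linear,
        input_is_parallel=True,     # consume the sharded activation directly
        skip_bias_add=True,
    )
    return fc1, fc2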