fix(model): fix model forward when checkpoint=true (#219)
mwiacx authored May 14, 2024
1 parent c4ede3a commit b0b23bb
Showing 5 changed files with 68 additions and 13 deletions.
11 changes: 8 additions & 3 deletions internlm/model/modeling_internlm.py
@@ -17,6 +17,8 @@
 from internlm.model.modules.mlp import new_feed_forward
 from internlm.model.modules.norm import new_layer_norm
 from internlm.model.utils import (
+    convert_attn_args_to_kwargs,
+    convert_attn_kwargs_to_args,
     internlm1_mha_pre_load_convert,
     internlm1_mha_save_convert,
 )
@@ -162,11 +164,13 @@ def reset_parameters(self):

     def forward(self, hidden_states, **kwargs):
         if self.checkpoint and self.training:
-            return activation_checkpoint(self._forward, False, hidden_states, **kwargs)
+            # NOTICE: activation_checkpoint does not support kwargs when use_reentrant = True.
+            args = convert_attn_kwargs_to_args(kwargs)
+            return activation_checkpoint(self._forward, False, hidden_states, *args)
         else:
             return self._forward(hidden_states, **kwargs)

-    def _forward(self, hidden_states=None, **kwargs):
+    def _forward(self, hidden_states, *args, **kwargs):
         r"""Pass the input through the encoder layer.

         Args:
@@ -190,7 +194,8 @@ def _dropout_and_norm_attn(_hidden_states):
         if self.residual_in_fp32:
             residual = residual.to(torch.float32)

-        hidden_states = self.mixer(hidden_states, **kwargs)
+        mixer_kwargs = convert_attn_args_to_kwargs(args, kwargs)
+        hidden_states = self.mixer(hidden_states, **mixer_kwargs)

         def _dropout_and_norm_ffn(_residual, _hidden_states):
             _dropped = self.dropout2(_hidden_states)
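Why the change is needed: the reentrant checkpointing path only replays positional arguments through its custom autograd Function and rejects keyword arguments outright, so passing attention metadata as kwargs breaks the forward pass when checkpoint=True. The sketch below is a minimal, standalone illustration of that limitation and of the pack-to-args workaround this commit applies. It assumes activation_checkpoint ultimately forwards to torch.utils.checkpoint.checkpoint with use_reentrant=True (as the NOTICE comment implies); the toy _layer_forward function and tensor shapes are illustrative stand-ins, not InternLM code.

# Minimal sketch (not InternLM code) of the failure this commit works around.
# Assumption: internlm.solver.activation_checkpoint forwards to torch.utils.checkpoint
# with use_reentrant=True, as the NOTICE comment in the hunks above implies.
import torch
from torch.utils.checkpoint import checkpoint


def _layer_forward(hidden_states, cu_seqlens=None, max_seqlen=None):
    # Stand-in for a decoder layer's _forward; the real layers run attention + MLP here.
    return hidden_states * 2


x = torch.randn(4, 8, requires_grad=True)

# Passing attention metadata as kwargs fails on the reentrant path:
try:
    checkpoint(_layer_forward, x, cu_seqlens=None, max_seqlen=16, use_reentrant=True)
except Exception as err:  # raised because kwargs are not supported with use_reentrant=True
    print(f"kwargs rejected: {err}")

# The workaround: flatten the kwargs into a fixed positional order before checkpointing
# (convert_attn_kwargs_to_args) and rebuild the dict inside _forward
# (convert_attn_args_to_kwargs), so only positional args cross the checkpoint boundary.
out = checkpoint(_layer_forward, x, None, 16, use_reentrant=True)
out.sum().backward()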
13 changes: 10 additions & 3 deletions internlm/model/modeling_internlm2.py
@@ -18,6 +18,10 @@
 from internlm.model.modules.mha import GQA
 from internlm.model.modules.mlp import new_feed_forward
 from internlm.model.modules.norm import new_layer_norm
+from internlm.model.utils import (
+    convert_attn_args_to_kwargs,
+    convert_attn_kwargs_to_args,
+)
 from internlm.solver.activation_checkpoint import activation_checkpoint
 from internlm.utils.logger import get_logger

@@ -197,11 +201,13 @@ def reset_parameters(self):

     def forward(self, hidden_states, residual=None, **kwargs):
         if self.checkpoint and self.training:
-            return activation_checkpoint(self._forward, False, hidden_states, residual, **kwargs)
+            # NOTICE: activation_checkpoint does not support kwargs when use_reentrant = True.
+            args = convert_attn_kwargs_to_args(kwargs)
+            return activation_checkpoint(self._forward, False, hidden_states, residual, *args)
         else:
             return self._forward(hidden_states, residual, **kwargs)

-    def _forward(self, hidden_states=None, residual=None, **kwargs):
+    def _forward(self, hidden_states, residual, *args, **kwargs):
         r"""Pass the input through the encoder layer.

         Args:
@@ -227,7 +233,8 @@ def _dropout_and_norm_attn(_residual, _hidden_states):
         if self.residual_in_fp32:
             residual = residual.to(torch.float32)

-        hidden_states = self.attention(hidden_states, **kwargs)
+        attn_kwargs = convert_attn_args_to_kwargs(args, kwargs)
+        hidden_states = self.attention(hidden_states, **attn_kwargs)

         if not isinstance(self.feed_forward, nn.Identity):
             if not self.fused_dropout_add_ln:
13 changes: 10 additions & 3 deletions internlm/model/modeling_llama.py
@@ -18,6 +18,10 @@
 from internlm.model.modules.mha import GQA
 from internlm.model.modules.mlp import new_feed_forward
 from internlm.model.modules.norm import new_layer_norm
+from internlm.model.utils import (
+    convert_attn_args_to_kwargs,
+    convert_attn_kwargs_to_args,
+)
 from internlm.solver.activation_checkpoint import activation_checkpoint
 from internlm.utils.logger import get_logger

@@ -189,11 +193,13 @@ def reset_parameters(self):

     def forward(self, hidden_states, residual=None, **kwargs):
         if self.checkpoint and self.training:
-            return activation_checkpoint(self._forward, False, hidden_states, residual, **kwargs)
+            # NOTICE: activation_checkpoint does not support kwargs when use_reentrant = True.
+            args = convert_attn_kwargs_to_args(kwargs)
+            return activation_checkpoint(self._forward, False, hidden_states, residual, *args)
         else:
             return self._forward(hidden_states, residual, **kwargs)

-    def _forward(self, hidden_states=None, residual=None, **kwargs):
+    def _forward(self, hidden_states, residual, *args, **kwargs):
         r"""Pass the input through the encoder layer.

         Args:
@@ -219,7 +225,8 @@ def _dropout_and_norm_attn(_residual, _hidden_states):
         if self.residual_in_fp32:
             residual = residual.to(torch.float32)

-        hidden_states = self.attention(hidden_states, **kwargs)
+        attn_kwargs = convert_attn_args_to_kwargs(args, kwargs)
+        hidden_states = self.attention(hidden_states, **attn_kwargs)

         if not isinstance(self.feed_forward, nn.Identity):
             if not self.fused_dropout_add_ln:
11 changes: 8 additions & 3 deletions internlm/model/modeling_moe.py
@@ -18,6 +18,8 @@
 from internlm.model.modules.norm import new_layer_norm
 from internlm.model.moe.moe import MoE
 from internlm.model.utils import (
+    convert_attn_args_to_kwargs,
+    convert_attn_kwargs_to_args,
     internlm1_mha_pre_load_convert,
     internlm1_mha_save_convert,
 )
@@ -179,11 +181,13 @@ def reset_parameters(self):
     def forward(self, hidden_states, **kwargs):
         if self.checkpoint and self.training:
             # TODO: check whether this will be affected by moe
-            return activation_checkpoint(self._forward, False, hidden_states, **kwargs)
+            # NOTICE: activation_checkpoint does not support kwargs when use_reentrant = True.
+            args = convert_attn_kwargs_to_args(kwargs)
+            return activation_checkpoint(self._forward, False, hidden_states, *args)
         else:
             return self._forward(hidden_states, **kwargs)

-    def _forward(self, hidden_states=None, **kwargs):
+    def _forward(self, hidden_states, *args, **kwargs):
         r"""Pass the input through the encoder layer.

         Args:
@@ -207,7 +211,8 @@ def _dropout_and_norm_attn(_hidden_states):
         if self.residual_in_fp32:
             residual = residual.to(torch.float32)

-        hidden_states = self.mixer(hidden_states, **kwargs)
+        mixer_kwargs = convert_attn_args_to_kwargs(args, kwargs)
+        hidden_states = self.mixer(hidden_states, **mixer_kwargs)

         def _dropout_and_norm_ffn(_residual, _hidden_states):
             _dropped = self.dropout2(_hidden_states)
33 changes: 32 additions & 1 deletion internlm/model/utils.py
@@ -1,4 +1,4 @@
-from typing import Dict
+from typing import Any, Dict, List

 from internlm.model.modules.mha import MHA

@@ -20,3 +20,34 @@ def internlm1_mha_save_convert(

if f"{prefix}wqkv.bias" in state_dict:
state_dict[f"{prefix}Wqkv.bias"] = state_dict.pop(f"{prefix}wqkv.bias")


def convert_attn_kwargs_to_args(kwargs) -> List[Any]:
inference_params = kwargs.get("inference_params", None)
cu_seqlens = kwargs.get("cu_seqlens", None)
indexes = kwargs.get("indexes", None)
max_seqlen = kwargs.get("max_seqlen", None)

return (inference_params, cu_seqlens, indexes, max_seqlen)


def convert_attn_args_to_kwargs(args, kwargs) -> Dict[str, Any]:
if len(args) == 0:
return kwargs

assert len(args) == 4, "args must be generate by convert_attn_kwargs_to_args function"

if args[0] is not None:
assert "inference_params" not in kwargs, "repeated 'inference_params' argument exists both in args and kwargs"
kwargs["inference_params"] = args[0]
if args[1] is not None:
assert "cu_seqlens" not in kwargs, "repeated 'cu_seqlens' argument exists both in args and kwargs"
kwargs["cu_seqlens"] = args[1]
if args[2] is not None:
assert "indexes" not in kwargs, "repeated 'indexes' argument exists both in args and kwargs"
kwargs["indexes"] = args[2]
if args[3] is not None:
assert "max_seqlen" not in kwargs, "repeated 'max_seqlen' argument exists both in args and kwargs"
kwargs["max_seqlen"] = args[3]

return kwargs
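For reference, a quick round trip through the two new helpers (values are illustrative; the import path is the one used in the hunks above). Note that convert_attn_kwargs_to_args returns a tuple even though it is annotated List[Any], and that positions carrying None are simply dropped when the kwargs dict is rebuilt.

# Illustrative round trip through the helpers added in this commit.
from internlm.model.utils import (
    convert_attn_args_to_kwargs,
    convert_attn_kwargs_to_args,
)

# Flatten attention kwargs into the fixed positional order
# (inference_params, cu_seqlens, indexes, max_seqlen) used by the checkpointed path.
kwargs = {"cu_seqlens": None, "indexes": None, "max_seqlen": 2048}
args = convert_attn_kwargs_to_args(kwargs)
assert args == (None, None, None, 2048)

# Inside _forward, rebuild the kwargs dict before calling the attention module;
# entries that were None do not reappear.
attn_kwargs = convert_attn_args_to_kwargs(args, {})
assert attn_kwargs == {"max_seqlen": 2048}

# With no positional args (the non-checkpointed path), the original kwargs pass through untouched.
assert convert_attn_args_to_kwargs((), {"max_seqlen": 2048}) == {"max_seqlen": 2048}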
