Commit 708260f
fix hf internlm nan bug (#295)
sallyjunjun authored Aug 9, 2024
1 parent 137deb3 commit 708260f
Showing 2 changed files with 2 additions and 3 deletions.
4 changes: 2 additions & 2 deletions internlm/core/parallel/comm/isp.py
@@ -795,7 +795,7 @@ class DistributedAttention(nn.Module):

     def __init__(
         self,
-        local_attention: nn.Module,
+        local_attention: Union[nn.Module, Callable],
         sequence_process_group: dist.ProcessGroup,
     ) -> None:
         super().__init__()
@@ -914,7 +914,7 @@ def _attetion_constructor(
     return partial(_attetion_constructor, local_attn_cls=cls)


-def auto_wrap_func_distributed_attention(func: Callable) -> Callable[[bool, Any, float], nn.Module]:
+def auto_wrap_func_distributed_attention(func: Callable) -> Callable[..., Callable]:
     """
     Wrap a local attention function to a distributed one, which will be used in the ISP parallelism.
     """
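Why the annotation widened: DistributedAttention previously assumed its local_attention was an nn.Module, but the function-wrapping path built by auto_wrap_func_distributed_attention hands it a bare callable. Below is a self-contained toy sketch of that pattern; ToyDistributedAttention and scaled_dot are hypothetical stand-ins, not the InternLM implementation.

# Toy sketch only -- ToyDistributedAttention and scaled_dot are hypothetical
# stand-ins, not the InternLM classes. It shows why local_attention is now
# annotated Union[nn.Module, Callable]: both are callable, so forward() can
# dispatch to either uniformly.
from typing import Callable, Union

import torch
from torch import nn


class ToyDistributedAttention(nn.Module):
    def __init__(self, local_attention: Union[nn.Module, Callable]) -> None:
        super().__init__()
        # In the real DistributedAttention, calls to local_attention sit
        # between sequence-parallel communication over sequence_process_group.
        self.local_attn = local_attention

    def forward(self, *args, **kwargs) -> torch.Tensor:
        return self.local_attn(*args, **kwargs)


def scaled_dot(q, k, v):
    # A bare attention function, with no nn.Module wrapper around it.
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ v


attn = ToyDistributedAttention(scaled_dot)  # a plain Callable is now valid
q = k = v = torch.randn(2, 4, 8)
print(attn(q, k, v).shape)  # torch.Size([2, 4, 8])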
1 change: 0 additions & 1 deletion internlm/model/ops/attention.py
@@ -1032,5 +1032,4 @@ def hf_q_k_v_with_cu_seqlens(
         return_attn_probs=False,
         causal=causal,
     )
-    attn_output = attn_output.unsqueeze(0)
     return attn_output
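On the NaN fix itself: the commit message gives no detail beyond the title, but the removed line added a leading batch dimension to the output of a varlen (packed, cu_seqlens-based) attention call, whose result is laid out over total tokens rather than batches. A toy shape illustration of that mismatch follows, with random tensors standing in for the attention output (an assumption, not a trace of the actual failure).

# Toy illustration with random tensors; the shape convention
# (total_tokens, num_heads, head_dim) is an assumption here,
# not taken from the InternLM code.
import torch

total_tokens, num_heads, head_dim = 10, 4, 8
attn_output = torch.randn(total_tokens, num_heads, head_dim)

batched = attn_output.unsqueeze(0)  # the removed line: adds a leading dim
print(attn_output.shape)  # torch.Size([10, 4, 8]) -- packed layout
print(batched.shape)      # torch.Size([1, 10, 4, 8]) -- spurious batch dim
# Downstream code expecting the packed layout would misinterpret the extra
# dim, which is one plausible route to the NaNs the commit title mentions.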
