'MambaSplitConv1dScanCombinedFnBackward' returned nan values in its 0th output. #672

Open
TheMakerOfWorlds opened this issue Jan 15, 2025 · 0 comments


TheMakerOfWorlds commented Jan 15, 2025

I've been debugging this, and it seems that whenever n_layer > 4 I get NaN values from MambaSplitConv1dScanCombinedFnBackward.

Any ideas on why this happens, or how to fix it, would be much appreciated!

After the first forward pass of the model I get my loss, and the error is then thrown on loss.backward() when running under torch.autograd.detect_anomaly().
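
For reference, the anomaly check is wired around the loss roughly like this (model, batch_context, batch_new_tokens and targets are placeholders here, not my exact training script):

import torch
import torch.nn.functional as F

# model, batch_context, batch_new_tokens and targets come from the training loop (placeholders).
with torch.autograd.detect_anomaly():
    out = model(context=batch_context, new_tokens=batch_new_tokens)
    logits = out["logits"]  # (batch_size, num_new_tokens, vocab_size)
    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),  # flatten token positions for cross-entropy
        targets.reshape(-1),
    )
    loss.backward()  # this is where 'MambaSplitConv1dScanCombinedFnBackward' reports NaN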


import torch
import torch.nn as nn

# GenerationMixin (from mamba_ssm) and create_mamba2_block (a project-local helper
# around mamba_ssm's Mamba2 block) are imported/defined elsewhere.


class Mamba2LMModel(nn.Module, GenerationMixin):
    """
    Mamba2-based autoregressive model that generates sequences from a given context.
    """

    def __init__(
        self,
        d_model=1024,
        n_layer=8,
        d_intermediate=2048,
        vocab_size=3072,
        norm_epsilon=1e-5,
        rms_norm=False,
        fused_add_norm=False,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}

        # Stack of Mamba2 blocks
        self.layers = nn.ModuleList(
            [
                create_mamba2_block(
                    d_model=d_model,
                    d_intermediate=d_intermediate,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    fused_add_norm=fused_add_norm,
                    residual_in_fp32=residual_in_fp32,
                    layer_idx=i,
                    **factory_kwargs,
                )
                for i in range(n_layer)
            ]
        )

        # Final normalization layer
        self.norm_f = nn.LayerNorm(d_model, eps=norm_epsilon, **factory_kwargs)

        # Linear head for logits
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)

        # Trainable embedding for tokens
        self.token_embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)

    def forward(self, context, new_tokens=None):
        """
        Forward pass with context and optional new token embeddings.

        Args:
            context: Tensor of shape (batch_size, context_length, d_model).
            new_tokens: Tensor of shape (batch_size, num_new_tokens) or None.

        Returns:
            dict: Contains logits for the new tokens.
        """

        # Debugging NaN values REMOVE LATER
        print("Input context stats:")
        print(f"Mean: {context.mean().item()}, Std: {context.std().item()}")
        print(f"Max: {context.max().item()}, Min: {context.min().item()}")
        if torch.isnan(context).any():
            print("NaN detected in input context!")

        # Process context through the Mamba2 layers without generating logits
        hidden_states = context  # (batch_size, context_length, d_model)
        residual = None

        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(hidden_states, residual)

            # Debugging NaN values REMOVE LATER
            if torch.isnan(hidden_states).any() or (
                residual is not None and torch.isnan(residual).any()
            ):
                print(f"NaN detected at layer {i}")
                print(f"Hidden states: {hidden_states}")
                if residual is not None:
                    print(f"Residual: {residual}")
                break

        if residual is not None:
            hidden_states = hidden_states + residual

        processed_context = self.norm_f(hidden_states)  # Final processed context

        if new_tokens is None:
            # Return the processed context without logits if no new tokens are provided
            return {"processed_context": processed_context}

        # Debugging NaN values REMOVE LATER
        print(f"Input IDs Stats: Min={new_tokens.min()}, Max={new_tokens.max()}")
        assert (new_tokens >= 0).all() and (
            new_tokens < self.token_embedding.num_embeddings
        ).all(), "Invalid input_ids detected!"
        assert not torch.isnan(new_tokens).any(), "NaN detected in input_ids!"

        # Embed new tokens and concatenate with the processed context
        new_token_embeddings = self.token_embedding(
            new_tokens
        )  # (batch_size, num_new_tokens, d_model)

        # Debugging NaN values REMOVE LATER -----
        print(
            f"Embedding Output Stats: Mean: {new_token_embeddings.mean().item()}, Std: {new_token_embeddings.std().item()}, "
            f"Max: {new_token_embeddings.max().item()}, Min: {new_token_embeddings.min().item()}"
        )
        if torch.isnan(new_token_embeddings).any():
            print("NaN detected in embedding output!")
        # Debugging NaN values REMOVE LATER -----

        combined_input = torch.cat([processed_context, new_token_embeddings], dim=1)

        # Pass combined input through layers to generate logits for the new tokens
        hidden_states = combined_input
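        # Clamp the concatenated activations to a fixed range as a crude guard
        # against exploding values before the second pass through the layers.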
        hidden_states = torch.clamp(hidden_states, min=-5.0, max=5.0)

        residual = None

        for layer in self.layers:
            hidden_states, residual = layer(hidden_states, residual)

        if residual is not None:
            hidden_states = hidden_states + residual

        final_states = self.norm_f(hidden_states)

        # Generate logits only for the new tokens
        logits = self.lm_head(
            final_states[:, -new_token_embeddings.size(1) :, :]
        )  # Only last num_new_tokens positions
        # token_embeddings = final_states[:, -new_token_embeddings.size(1) :, :]

        return {
            "logits": logits,
        }  # (batch_size, num_new_tokens, vocab_size)
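
For debugging, a small helper like this (not part of the model above, just a sketch) can show which parameters end up with non-finite gradients after loss.backward():

import torch

def report_nonfinite_grads(model):
    """Print every parameter whose gradient contains NaN/Inf after backward."""
    for name, param in model.named_parameters():
        if param.grad is not None and not torch.isfinite(param.grad).all():
            print(
                f"non-finite grad in {name}: "
                f"min={param.grad.min().item()}, max={param.grad.max().item()}"
            )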


My initialization is

import math

import torch
import torch.nn as nn


def init_mamba2_weights(
    module,
    dt_bias_init: float = -5.0,
    a_log_init: float = 0.0,
    conv_gain: float = 1.0,
    std_linear: float = 0.02,
    std_embedding: float = 0.02,
):
    """
    Enhanced initialization for Mamba2-based models to prevent NaN values.
    """
    # ----- LayerNorm -----
    if isinstance(module, nn.LayerNorm):
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)

    # ----- Embeddings -----
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=std_embedding)

    # ----- Linear Layers -----
    elif isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight, gain=1.0)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    # ----- Conv1D Layers -----
    elif isinstance(module, nn.Conv1d):
        fan_in = module.in_channels * module.kernel_size[0]
        fan_out = module.out_channels * module.kernel_size[0]
        limit = conv_gain * math.sqrt(6.0 / (fan_in + fan_out))
        nn.init.uniform_(module.weight, -limit, limit)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    # ----- Mamba2-Specific Parameters -----
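    # Note (mamba_ssm convention): the effective step size is softplus(raw_dt + dt_bias)
    # and A = -exp(A_log), so dt_bias_init=-5.0 gives a very small initial dt and
    # a_log_init=0.0 gives A = -1.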
    if hasattr(module, "dt_bias"):
        with torch.no_grad():
            module.dt_bias.fill_(dt_bias_init)

    if hasattr(module, "A_log"):
        with torch.no_grad():
            module.A_log.fill_(a_log_init)

    if hasattr(module, "D"):
        with torch.no_grad():
            module.D.fill_(1.0)

    # Debugging: Print if any NaNs exist after initialization
    for name, param in module.named_parameters(recurse=False):
        if param.requires_grad and torch.isnan(param).any():
            print(f"Warning: NaN detected in {name} after initialization.")