I've been debugging this, and it seems that whenever `n_layer > 4` I get NaN values in `MambaSplitConv1dScanCombinedFnBackward`.
If anyone has ideas on why this happens or how to fix it, that would be much appreciated!
After the first forward pass of the model I get my loss, and then the error is thrown on `loss.backward()` when running under `torch.autograd.detect_anomaly()`.
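For reference, the failure shows up in a training step of roughly this shape, using the model class shown below. This is a minimal sketch only: `context`, `new_tokens`, `targets`, the cross-entropy loss, and the CUDA device are placeholders for my actual pipeline, not the exact code.

```python
import torch
import torch.nn.functional as F

# Placeholder inputs matching the model defaults (d_model=1024, vocab_size=3072).
model = Mamba2LMModel(n_layer=8).cuda()
context = torch.randn(2, 64, 1024, device="cuda")            # (batch, context_len, d_model)
new_tokens = torch.randint(0, 3072, (2, 16), device="cuda")  # (batch, num_new_tokens)
targets = torch.randint(0, 3072, (2, 16), device="cuda")

with torch.autograd.detect_anomaly():
    out = model(context, new_tokens)
    loss = F.cross_entropy(out["logits"].reshape(-1, 3072), targets.reshape(-1))
    # The forward pass completes and the loss is finite; anomaly detection then raises
    # on MambaSplitConv1dScanCombinedFnBackward during this call when n_layer > 4.
    loss.backward()
```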
```python
class Mamba2LMModel(nn.Module, GenerationMixin):
    """
    Mamba2-based autoregressive model that generates sequences from a given context.
    """

    def __init__(
        self,
        d_model=1024,
        n_layer=8,
        d_intermediate=2048,
        vocab_size=3072,
        norm_epsilon=1e-5,
        rms_norm=False,
        fused_add_norm=False,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        # Stack of Mamba2 blocks
        self.layers = nn.ModuleList(
            [
                create_mamba2_block(
                    d_model=d_model,
                    d_intermediate=d_intermediate,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    fused_add_norm=fused_add_norm,
                    residual_in_fp32=residual_in_fp32,
                    layer_idx=i,
                    **factory_kwargs,
                )
                for i in range(n_layer)
            ]
        )
        # Final normalization layer
        self.norm_f = nn.LayerNorm(d_model, eps=norm_epsilon, **factory_kwargs)
        # Linear head for logits
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
        # Trainable embedding for tokens
        self.token_embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)

    def forward(self, context, new_tokens=None):
        """
        Forward pass with context and optional new token embeddings.

        Args:
            context: Tensor of shape (batch_size, context_length, d_model).
            new_tokens: Tensor of shape (batch_size, num_new_tokens) or None.

        Returns:
            dict: Contains logits for the new tokens.
        """
        # Debugging NaN values REMOVE LATER
        print("Input context stats:")
        print(f"Mean: {context.mean().item()}, Std: {context.std().item()}")
        print(f"Max: {context.max().item()}, Min: {context.min().item()}")
        if torch.isnan(context).any():
            print("NaN detected in input context!")

        # Process context through the Mamba2 layers without generating logits
        hidden_states = context  # (batch_size, context_length, d_model)
        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(hidden_states, residual)
            # Debugging NaN values REMOVE LATER
            if torch.isnan(hidden_states).any() or (
                residual is not None and torch.isnan(residual).any()
            ):
                print(f"NaN detected at layer {i}")
                print(f"Hidden states: {hidden_states}")
                if residual is not None:
                    print(f"Residual: {residual}")
                break
        if residual is not None:
            hidden_states = hidden_states + residual
        processed_context = self.norm_f(hidden_states)  # Final processed context

        if new_tokens is None:
            # Return the processed context without logits if no new tokens are provided
            return {"processed_context": processed_context}

        # Debugging NaN values REMOVE LATER
        print(f"Input IDs Stats: Min={new_tokens.min()}, Max={new_tokens.max()}")
        assert (new_tokens >= 0).all() and (
            new_tokens < self.token_embedding.num_embeddings
        ).all(), "Invalid input_ids detected!"
        assert not torch.isnan(new_tokens).any(), "NaN detected in input_ids!"

        # Embed new tokens and concatenate with the processed context
        new_token_embeddings = self.token_embedding(
            new_tokens
        )  # (batch_size, num_new_tokens, d_model)
# Debugging NaN values REMOVE LATER -----
print(
f"Embedding Output Stats: Mean: {hidden_states.mean().item()}, Std: {hidden_states.std().item()}, "
f"Max: {hidden_states.max().item()}, Min: {hidden_states.min().item()}"
)
if torch.isnan(hidden_states).any():
print("NaN detected in embedding output!")
# Debugging NaN values REMOVE LATER -----
        combined_input = torch.cat([processed_context, new_token_embeddings], dim=1)

        # Pass combined input through layers to generate logits for the new tokens
        hidden_states = combined_input
        hidden_states = torch.clamp(hidden_states, min=-5.0, max=5.0)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(hidden_states, residual)
        if residual is not None:
            hidden_states = hidden_states + residual
        final_states = self.norm_f(hidden_states)

        # Generate logits only for the new tokens
        logits = self.lm_head(
            final_states[:, -new_token_embeddings.size(1) :, :]
        )  # Only last num_new_tokens positions
        # token_embeddings = final_states[:, -new_token_embeddings.size(1) :, :]
        return {
            "logits": logits,
        }  # (batch_size, num_new_tokens, vocab_size)
```
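To narrow down which layer's backward pass first produces non-finite values, one generic PyTorch option (not specific to mamba_ssm, and just a diagnostic sketch with a hypothetical helper name) is to register gradient hooks on every parameter:

```python
import torch

def add_nan_grad_hooks(model):
    """Report any parameter whose gradient contains NaN/inf during backward (diagnostic sketch)."""
    def make_hook(name):
        def hook(grad):
            if not torch.isfinite(grad).all():
                print(f"non-finite gradient in {name}, shape {tuple(grad.shape)}")
            return grad
        return hook

    for name, param in model.named_parameters():
        if param.requires_grad:
            param.register_hook(make_hook(name))
```

Calling `add_nan_grad_hooks(model)` once before training and then running `loss.backward()` prints every parameter whose gradient went non-finite, which helps tell whether the problem starts in the deeper blocks (e.g. their `dt_bias`/`A_log`) or in the conv/linear weights.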
My initialization is:
```python
def init_mamba2_weights(
    module,
    dt_bias_init: float = -5.0,
    a_log_init: float = 0.0,
    conv_gain: float = 1.0,
    std_linear: float = 0.02,
    std_embedding: float = 0.02,
):
    """
    Enhanced initialization for Mamba2-based models to prevent NaN values.
    """
    # ----- LayerNorm -----
    if isinstance(module, nn.LayerNorm):
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)
    # ----- Embeddings -----
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=std_embedding)
    # ----- Linear Layers -----
    elif isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight, gain=1.0)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    # ----- Conv1D Layers -----
    elif isinstance(module, nn.Conv1d):
        fan_in = module.in_channels * module.kernel_size[0]
        fan_out = module.out_channels * module.kernel_size[0]
        limit = conv_gain * math.sqrt(6.0 / (fan_in + fan_out))
        nn.init.uniform_(module.weight, -limit, limit)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    # ----- Mamba2-Specific Parameters -----
    if hasattr(module, "dt_bias"):
        with torch.no_grad():
            module.dt_bias.fill_(dt_bias_init)
    if hasattr(module, "A_log"):
        with torch.no_grad():
            module.A_log.fill_(a_log_init)
    if hasattr(module, "D"):
        with torch.no_grad():
            module.D.fill_(1.0)
    # Debugging: Print if any NaNs exist after initialization
    for name, param in module.named_parameters(recurse=False):
        if param.requires_grad and torch.isnan(param).any():
            print(f"Warning: NaN detected in {name} after initialization.")
```