Gradients become NaN when applying DAS to CLIP #11

Open
yunsangju opened this issue Jul 19, 2024 · 0 comments

Hello,

I am applying the DAS method to CLIP. When computing the importance scores, the text model produces well-behaved gradients, but the vision model's mask gradients are mostly NaN. The units used for the importance computation are placed on the self_attn and mlp outputs of CLIPEncoderLayer, and the same CLIPEncoderLayer implementation is shared by the text and vision models.

I have declared the masks as follows:

class CLIPEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`CLIPEncoderLayer`].

    Args:
        config: CLIPConfig
    """

    def __init__(self, config: CLIPConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        # One mask per encoder layer over the hidden dimension, applied to the
        # self-attention output and the MLP output respectively. They are kept
        # as plain leaf tensors (not nn.Parameter) in float16.
        self.self_attn_mask = torch.ones(config.num_hidden_layers, config.hidden_size, dtype=torch.float16)
        self.self_attn_mask.requires_grad_(True)
        self.mlp_mask = torch.ones(config.num_hidden_layers, config.hidden_size, dtype=torch.float16)
        self.mlp_mask.requires_grad_(True)
        
        self.gradient_checkpointing = False
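
As a sanity check on the declaration itself, here is a minimal standalone sketch of the mask mechanism outside of CLIP (the sizes below are placeholders, not the real configuration). It only illustrates that a gradient can reach a plain leaf tensor declared this way, including through the .to(device) call done in the forward pass:

import torch

# Placeholder sizes, not the real CLIP configuration.
num_layers, hidden_size = 2, 8

# Mask declared the same way as above: a plain leaf tensor with requires_grad.
self_attn_mask = torch.ones(num_layers, hidden_size, dtype=torch.float16)
self_attn_mask.requires_grad_(True)

x = torch.randn(4, hidden_size, dtype=torch.float16)

# .to(device) may return a new non-leaf tensor, but autograd still routes the
# gradient back to the original leaf mask.
mask_on_device = self_attn_mask.to(x.device)
out = x * mask_on_device[0]                    # apply the layer-0 mask
out.sum().backward()

print(self_attn_mask.grad.shape)               # torch.Size([2, 8])
print(torch.isnan(self_attn_mask.grad).any())  # tensor(False) in this toy case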

Inside CLIPEncoderLayer.forward, the masks are applied as follows:

residual = hidden_states

hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
    hidden_states=hidden_states,
    attention_mask=attention_mask,
    causal_attention_mask=causal_attention_mask,
    output_attentions=output_attentions,
)
# Apply the layer's attention mask to the attention output.
self_attn_mask = self_attn_mask.to(hidden_states.device)
hidden_states = hidden_states * self_attn_mask
hidden_states = residual + hidden_states

# Apply the layer's MLP mask to the MLP output.
mlp_mask = mlp_mask.to(hidden_states.device)
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = hidden_states * mlp_mask
hidden_states = residual + hidden_states

outputs = (hidden_states,)
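
For completeness, here is a self-contained toy version of the same residual-plus-mask pattern that can be run outside of CLIP to check whether NaNs appear in the mask gradients at all. The LayerNorm/Linear/MLP modules and all sizes below are stand-ins, not the real CLIP modules, and the script falls back to float32 on CPU since half-precision ops may be incomplete there in older PyTorch versions:

import torch
import torch.nn as nn

torch.manual_seed(0)

# All sizes and stand-in modules below are placeholders, not the real CLIP
# vision/text configuration.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
hidden_size = 16

layer_norm1 = nn.LayerNorm(hidden_size).to(device, dtype)
layer_norm2 = nn.LayerNorm(hidden_size).to(device, dtype)
self_attn = nn.Linear(hidden_size, hidden_size).to(device, dtype)  # stand-in for self_attn
mlp = nn.Sequential(
    nn.Linear(hidden_size, 4 * hidden_size),
    nn.GELU(),
    nn.Linear(4 * hidden_size, hidden_size),
).to(device, dtype)

# Masks over the hidden dimension for this single toy layer
# (the real masks above have one row per encoder layer).
self_attn_mask = torch.ones(hidden_size, dtype=dtype, requires_grad=True)
mlp_mask = torch.ones(hidden_size, dtype=dtype, requires_grad=True)

hidden_states = torch.randn(2, 5, hidden_size, dtype=dtype, device=device)

# Same residual pattern as in the CLIPEncoderLayer.forward snippet above.
residual = hidden_states
hidden_states = layer_norm1(hidden_states)
hidden_states = self_attn(hidden_states)
hidden_states = hidden_states * self_attn_mask.to(device)
hidden_states = residual + hidden_states

residual = hidden_states
hidden_states = layer_norm2(hidden_states)
hidden_states = mlp(hidden_states)
hidden_states = hidden_states * mlp_mask.to(device)
hidden_states = residual + hidden_states

hidden_states.sum().backward()

# If the forward/backward overflows anywhere, NaN/Inf shows up in these grads.
print(torch.isnan(self_attn_mask.grad).any(), torch.isnan(mlp_mask.grad).any())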

I would like to ask whether you have observed the same phenomenon, or whether my implementation is incorrect.

Thank you.
