
Potential Misalignment in MSDeformableAttention.forward Implementation for reference_points.shape[-1] == 2 #505

Open · Anchor1566 opened this issue Dec 3, 2024 · 2 comments

Anchor1566 commented Dec 3, 2024

First of all, thank you for your excellent work!

While studying the implementation, I came across a potential mistake in the forward method of MSDeformableAttention, in the branch where reference_points.shape[-1] == 2. I suspect there is a dimension misalignment in the computation of sampling_locations, in the following lines:

if reference_points.shape[-1] == 2:
    offset_normalizer = torch.tensor(value_spatial_shapes)
    offset_normalizer = offset_normalizer.flip([1]).reshape(1, 1, 1, self.num_levels, 1, 2)
    sampling_locations = reference_points.reshape(bs, Len_q, 1, self.num_levels, 1, 2) + sampling_offsets / offset_normalizer

Here’s my reasoning:

1. The shape of reference_points after reshaping is (bs, Len_q, 1, self.num_levels, 1, 2).
2. The shape of sampling_offsets is (bs, Len_q, self.num_heads, sum(self.num_points_list), 2).
3. The shape of offset_normalizer is (1, 1, 1, self.num_levels, 1, 2).

Given these shapes, reference_points and sampling_offsets / offset_normalizer do not appear to align: reference_points lacks the self.num_heads and sum(self.num_points_list) dimensions, so broadcasting ends up pairing self.num_levels against sum(self.num_points_list) (and against self.num_heads), which fails whenever they differ, as the minimal sketch below reproduces.
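
The following standalone sketch reproduces the mismatch with dummy tensors (the sizes bs=2, Len_q=300, num_heads=8, num_levels=3, num_points_list=[4, 4, 4] and the spatial shapes are made-up examples, not values from the repo):

import torch

bs, Len_q, num_heads, num_levels = 2, 300, 8, 3
num_points_list = [4, 4, 4]                               # sum(num_points_list) == 12
value_spatial_shapes = [(80, 80), (40, 40), (20, 20)]     # hypothetical level sizes

# Dummy tensors with the shapes listed above
reference_points = torch.rand(bs, Len_q, num_levels, 2)
sampling_offsets = torch.rand(bs, Len_q, num_heads, sum(num_points_list), 2)

offset_normalizer = torch.tensor(value_spatial_shapes)                              # (num_levels, 2)
offset_normalizer = offset_normalizer.flip([1]).reshape(1, 1, 1, num_levels, 1, 2)

# (bs, Len_q, num_heads, 12, 2) vs (1, 1, 1, num_levels, 1, 2): aligned from the right,
# num_heads (8) meets num_levels (3), so even the division cannot broadcast, and
# neither can the subsequent addition with the reshaped reference_points.
try:
    sampling_locations = reference_points.reshape(bs, Len_q, 1, num_levels, 1, 2) \
                         + sampling_offsets / offset_normalizer
except RuntimeError as e:
    print(e)  # size-mismatch error reported by PyTorch broadcasting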

My Attempted Fix

To resolve this, I made a light modification so the shapes broadcast: sampling_offsets is reshaped to (bs, Len_q, self.num_heads, self.num_levels, self.num_points, 2), which lets the reshaped reference_points broadcast across the head and point dimensions, and the result is then flattened back to (bs, Len_q, self.num_heads, self.num_levels * self.num_points, 2):

def forward(self,
                query: torch.Tensor,
                reference_points: torch.Tensor,
                value: torch.Tensor,
                value_spatial_shapes: List[int],
                value_mask: torch.Tensor=None):
        """
        Args:
            query (Tensor): [bs, query_length, C]
            reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0),
                bottom-right (1, 1), including padding area
            value (Tensor): [bs, value_length, C]
            value_spatial_shapes (List): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
            value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements

        Returns:
            output (Tensor): [bs, Length_{query}, C]
        """
        bs, Len_q = query.shape[:2]
        Len_v = value.shape[1]

        value = self.value_proj(value)
        if value_mask is not None:
            value = value * value_mask.to(value.dtype).unsqueeze(-1)

        value = value.reshape(bs, Len_v, self.num_heads, self.head_dim)

        sampling_offsets: torch.Tensor = self.sampling_offsets(query)
        sampling_offsets = sampling_offsets.reshape(bs, Len_q, self.num_heads, self.num_levels, self.num_points, 2)

        attention_weights = self.attention_weights(query).reshape(bs, Len_q, self.num_heads, sum(self.num_points_list))
        attention_weights = F.softmax(attention_weights, dim=-1).reshape(bs, Len_q, self.num_heads, sum(self.num_points_list))

        if reference_points.shape[-1] == 2:
            offset_normalizer = torch.tensor(value_spatial_shapes).flip([1]).to(reference_points.device)
            offset_normalizer = offset_normalizer.reshape(1, 1, 1, self.num_levels, 1, 2)
            sampling_locations = reference_points.reshape(bs, Len_q, 1, self.num_levels, 1, 2) \
                                + sampling_offsets / offset_normalizer
            
            sampling_locations = sampling_locations.reshape(bs, Len_q, self.num_heads, self.num_levels * self.num_points, 2)
        elif reference_points.shape[-1] == 4:
            sampling_locations = reference_points[:, :, None, :, None, :2] \
                                + sampling_offsets / self.num_points * reference_points[:, :, None, :, None, 2:] * 0.5
            
            sampling_locations = sampling_locations.reshape(bs, Len_q, self.num_heads, self.num_levels * self.num_points, 2)
        else:
            raise ValueError(
                "Last dim of reference_points must be 2 or 4, but get {} instead.".
                format(reference_points.shape[-1]))

        output = self.ms_deformable_attn_core(value, value_spatial_shapes, sampling_locations, attention_weights, self.num_points_list)
        output = self.output_proj(output)

        return output

With this modification, the broadcasting appears to work as expected, but I’m unsure if this fully aligns with the intended functionality.
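
As a quick sanity check of that broadcasting (again with made-up sizes, assuming a uniform num_points per level):

import torch

bs, Len_q, num_heads, num_levels, num_points = 2, 300, 8, 3, 4
value_spatial_shapes = [(80, 80), (40, 40), (20, 20)]

reference_points = torch.rand(bs, Len_q, num_levels, 2)
sampling_offsets = torch.rand(bs, Len_q, num_heads, num_levels, num_points, 2)  # 6-D, per the fix above

offset_normalizer = torch.tensor(value_spatial_shapes).flip([1]).reshape(1, 1, 1, num_levels, 1, 2)
sampling_locations = reference_points.reshape(bs, Len_q, 1, num_levels, 1, 2) \
                     + sampling_offsets / offset_normalizer
print(sampling_locations.shape)  # torch.Size([2, 300, 8, 3, 4, 2])

sampling_locations = sampling_locations.reshape(bs, Len_q, num_heads, num_levels * num_points, 2)
print(sampling_locations.shape)  # torch.Size([2, 300, 8, 12, 2])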
Thank you again for your amazing work!

lyuwenyu (Owner) commented Dec 5, 2024

Yes, you're right. The rtdetrv2 code didn't consider the case of reference_points.shape[-1] == 2.

But I think it should be like this:

offset_normalizer = []
for shape, num_points in zip(value_spatial_shapes, self.num_points_list):
    repeated_shapes = [shape] * num_points
    offset_normalizer.extend(repeated_shapes)
offset_normalizer = torch.tensor(offset_normalizer).flip(1).to(reference_points.device) # sum(self.num_points_list), 2
sampling_locations = reference_points[:, :, None, :, :] + sampling_offsets / offset_normalizer
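
For completeness, a small standalone check (the spatial shapes and num_points_list below are hypothetical examples, not values from the repo) of what the suggested offset_normalizer looks like: each level's (H, W) is repeated once per sampling point and flipped to (W, H), giving a (sum(num_points_list), 2) tensor that lines up with the point dimension of sampling_offsets.

import torch

value_spatial_shapes = [(80, 80), (40, 40), (20, 20)]   # hypothetical level sizes
num_points_list = [4, 4, 4]

offset_normalizer = []
for shape, num_points in zip(value_spatial_shapes, num_points_list):
    offset_normalizer.extend([shape] * num_points)       # repeat each (H, W) once per point
offset_normalizer = torch.tensor(offset_normalizer).flip(1)

print(offset_normalizer.shape)   # torch.Size([12, 2]) == (sum(num_points_list), 2)

Presumably reference_points would also need its level entries repeated per point (e.g. reference_points.repeat_interleave(torch.tensor(num_points_list), dim=2)) so that the final addition broadcasts against the (bs, Len_q, num_heads, sum(num_points_list), 2) offsets; that step is not spelled out in the snippet above.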

Anchor1566 (Author):

Thank you for your reply 😊
