First of all, thank you for your excellent work!

While studying the implementation, I came across a potential mistake in the forward method of MSDeformableAttention, in the branch where reference_points.shape[-1] == 2. I suspect there is a dimension misalignment in the line that computes sampling_locations (reference_points.reshape(...) + sampling_offsets / offset_normalizer, shown in the code below). Here's my reasoning:
1. After reshaping, reference_points has shape (bs, Len_q, 1, self.num_levels, 1, 2).
2. sampling_offsets has shape (bs, Len_q, self.num_heads, sum(self.num_points_list), 2).
3. offset_normalizer has shape (1, 1, 1, self.num_levels, 1, 2).

Given these shapes, reference_points and sampling_offsets / offset_normalizer do not broadcast correctly: sampling_offsets is 5-D because its level and point dimensions are folded into a single sum(self.num_points_list) axis, while the other two tensors are 6-D, so its self.num_heads dimension ends up aligned against their self.num_levels dimension.
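To make the mismatch concrete, here is a minimal standalone sketch of the shapes above (the batch/head/level/point sizes are placeholders picked for illustration, not the repo's defaults):

import torch

bs, Len_q = 2, 10
num_heads, num_levels = 8, 4
num_points_list = [4, 4, 4, 4]  # hypothetical per-level point counts

reference_points = torch.rand(bs, Len_q, num_levels, 2)
sampling_offsets = torch.rand(bs, Len_q, num_heads, sum(num_points_list), 2)
offset_normalizer = torch.rand(1, 1, 1, num_levels, 1, 2)

# Mirrors the original branch: sampling_offsets is 5-D while offset_normalizer is 6-D,
# so the division already raises, with num_heads (8) lined up against num_levels (4).
try:
    _ = reference_points.reshape(bs, Len_q, 1, num_levels, 1, 2) \
        + sampling_offsets / offset_normalizer
except RuntimeError as e:
    print("broadcast error:", e)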
My Attempted Fix
To resolve this, I modified the code slightly to make the shapes compatible; the main change is reshaping sampling_offsets with separate level and point dimensions so that it broadcasts against the reshaped reference_points:
def forward(self,
            query: torch.Tensor,
            reference_points: torch.Tensor,
            value: torch.Tensor,
            value_spatial_shapes: List[int],
            value_mask: torch.Tensor = None):
    """
    Args:
        query (Tensor): [bs, query_length, C]
        reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0, 0),
            bottom-right (1, 1), including padding area
        value (Tensor): [bs, value_length, C]
        value_spatial_shapes (List): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
        value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements

    Returns:
        output (Tensor): [bs, Length_{query}, C]
    """
    bs, Len_q = query.shape[:2]
    Len_v = value.shape[1]

    value = self.value_proj(value)
    if value_mask is not None:
        value = value * value_mask.to(value.dtype).unsqueeze(-1)
    value = value.reshape(bs, Len_v, self.num_heads, self.head_dim)

    sampling_offsets: torch.Tensor = self.sampling_offsets(query)
    # Modified: keep separate level and point dimensions so the offsets
    # broadcast against the reshaped reference_points below.
    sampling_offsets = sampling_offsets.reshape(
        bs, Len_q, self.num_heads, self.num_levels, self.num_points, 2)

    attention_weights = self.attention_weights(query).reshape(
        bs, Len_q, self.num_heads, sum(self.num_points_list))
    attention_weights = F.softmax(attention_weights, dim=-1).reshape(
        bs, Len_q, self.num_heads, sum(self.num_points_list))

    if reference_points.shape[-1] == 2:
        offset_normalizer = torch.tensor(value_spatial_shapes).flip([1]).to(reference_points.device)
        offset_normalizer = offset_normalizer.reshape(1, 1, 1, self.num_levels, 1, 2)
        sampling_locations = reference_points.reshape(bs, Len_q, 1, self.num_levels, 1, 2) \
            + sampling_offsets / offset_normalizer
        sampling_locations = sampling_locations.reshape(
            bs, Len_q, self.num_heads, self.num_levels * self.num_points, 2)
    elif reference_points.shape[-1] == 4:
        sampling_locations = reference_points[:, :, None, :, None, :2] \
            + sampling_offsets / self.num_points * reference_points[:, :, None, :, None, 2:] * 0.5
        sampling_locations = sampling_locations.reshape(
            bs, Len_q, self.num_heads, self.num_levels * self.num_points, 2)
    else:
        raise ValueError(
            "Last dim of reference_points must be 2 or 4, but get {} instead.".format(
                reference_points.shape[-1]))

    output = self.ms_deformable_attn_core(
        value, value_spatial_shapes, sampling_locations, attention_weights, self.num_points_list)
    output = self.output_proj(output)
    return output
With this modification, the broadcasting appears to work as expected, but I’m unsure if this fully aligns with the intended functionality.
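For reference, a quick standalone shape check with the same placeholder sizes as above shows what the modified broadcast produces:

import torch

bs, Len_q, num_heads, num_levels, num_points = 2, 10, 8, 4, 4

reference_points = torch.rand(bs, Len_q, num_levels, 2)
# After the modified reshape, the offsets carry explicit level and point dimensions.
sampling_offsets = torch.rand(bs, Len_q, num_heads, num_levels, num_points, 2)
offset_normalizer = torch.rand(1, 1, 1, num_levels, 1, 2)

sampling_locations = reference_points.reshape(bs, Len_q, 1, num_levels, 1, 2) \
    + sampling_offsets / offset_normalizer
print(sampling_locations.shape)  # torch.Size([2, 10, 8, 4, 4, 2])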
Thank you again for your amazing work!