[pull] main from NVIDIA:main #46

Merged 2 commits on Jan 8, 2025
60 changes: 33 additions & 27 deletions tests/jax/test_fused_attn.py
@@ -148,30 +148,30 @@ def make_mask(
segment_ids: [1, 1, 1, 0, 2, 2, 2, 3, 3, 3, 4, 0, 0, 5, 5, 5]
segment_pos: [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
"""
# segment masks
inv_mask = make_attention_mask(
segment_ids_q, segment_ids_kv, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0))
)

if segment_pos_q is None:
segment_pos_q = jnp.broadcast_to(
jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
)
if segment_pos_kv is None:
segment_pos_kv = jnp.broadcast_to(
jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
)

# causal mask
if attn_mask_type.is_causal():
if segment_pos_q is None:
segment_pos_q = jnp.broadcast_to(
jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape
)
if segment_pos_kv is None:
segment_pos_kv = jnp.broadcast_to(
jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape
)
inv_causal_mask = make_attention_mask(
segment_pos_q, segment_pos_kv, lambda x, y: jnp.greater_equal(x, y)
)
inv_mask = combine_masks(inv_causal_mask, inv_mask)

if window_size is not None:
max_seqlen_q = inv_mask.shape[-2]
max_seqlen_kv = inv_mask.shape[-1]
inv_swa_mask = make_swa_mask(max_seqlen_q, max_seqlen_kv, window_size, attn_mask_type)
inv_swa_mask = jnp.broadcast_to(inv_swa_mask, inv_mask.shape)
inv_mask = combine_masks(inv_mask, inv_swa_mask)

# sliding window mask
inv_swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, jnp.bool_)
inv_mask = combine_masks(inv_mask, inv_swa_mask)
mask = jnp.logical_not(inv_mask)
return mask
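
For reference, a minimal sketch (not part of the diff) of how the updated test composes the segment mask with the per-segment sliding-window mask. It assumes make_swa_mask is imported from transformer_engine.jax.attention as in this PR; because the window is computed from segment_pos, it restarts inside each packed segment, and the segment mask removes any cross-segment entries:

import jax.numpy as jnp
from flax.linen import combine_masks, make_attention_mask
from transformer_engine.jax.attention import make_swa_mask

# One packed row with two segments of lengths 3 and 5.
segment_ids = jnp.array([[1, 1, 1, 2, 2, 2, 2, 2]])
segment_pos = jnp.array([[0, 1, 2, 0, 1, 2, 3, 4]])

# Keep entries only within the same non-padding segment.
inv_seg_mask = make_attention_mask(
    segment_ids, segment_ids, lambda x, y: jnp.logical_and(jnp.equal(x, y), x != 0)
)
# Sliding window over per-segment positions: query i attends to [i - 2, i].
inv_swa_mask = make_swa_mask(segment_pos, segment_pos, window_size=(2, 0), dtype=jnp.bool_)
inv_mask = combine_masks(inv_seg_mask, inv_swa_mask)  # shape (1, 1, 8, 8)
mask = jnp.logical_not(inv_mask)  # True marks masked-out positions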

@@ -314,13 +314,6 @@ def _get_max_segments_per_sequence(self):
return self.num_segments_per_seq + 1

def _check_configs(self):
# TODO(rewang): Fix THD + PADDING_CAUSAL + SWA reference
if (
self.qkv_layout.is_thd()
and self.attn_mask_type == AttnMaskType.PADDING_CAUSAL_MASK
and self.window_size is not None
):
pytest.skip("THD + PADDING_CAUSAL + SWA reference is not implemented.")
# TODO(rewang): probably add this in is_fused_attn_available
if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding():
pytest.skip("THD format requires padding masks.")
@@ -432,7 +425,12 @@ def gen_valid(bs, max_seqlen, pad_ratio):
return tokens, jnp.logical_not(tokens)

def generate_random_segment_ids(
batch_size, sequence_length, num_segments, seed, with_segment_pad=True
batch_size,
sequence_length,
num_segments,
seed,
with_segment_pad=True,
min_segment_len=None,
):
rng = np.random.default_rng(seed=seed)
# [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad
@@ -448,15 +446,20 @@ def generate_random_segment_ids(
current_pos = 0
segment_id = 1

for _ in range(num_segments):
segment_size = rng.integers(1, max_segment_size + 1)
for seg_id in range(num_segments):
# min_segment_len forces kv_len >= q_len because the cuDNN kernels fail otherwise
# TODO(rewang): Remove this constraint once cuDNN supports kv_len < q_len
min_segment_size = 1
if min_segment_len is not None:
min_segment_size = min_segment_len[i][seg_id]
segment_size = rng.integers(min_segment_size, max_segment_size + 1)
if current_pos + segment_size > sequence_length:
break
segment_end = current_pos + segment_size
segment_ids[i, current_pos:segment_end] = segment_id
segment_pos[i, current_pos:segment_end] = np.arange(segment_size)
if with_segment_pad:
num_valid = rng.integers(1, segment_size + 1)
num_valid = rng.integers(min_segment_size, segment_size + 1)
segment_pad[i, current_pos + num_valid : segment_end] = 1
current_pos = segment_end
segment_id += 1
@@ -473,18 +476,21 @@ def generate_random_segment_ids(
self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
)
self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
if self.qkv_layout == QKVLayout.T3HD:
self.segment_ids_kv = self.segment_ids_q
self.segment_pos_kv = self.segment_pos_q
self.pad_kv = self.pad_q
else:
# Force kv_len >= q_len for SWA; otherwise, the cuDNN kernels don't support it
min_segment_len = None if self.window_size is None else self.seqlens_q
self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
self.batch_size,
self.max_seqlen_kv,
self.num_segments_per_seq,
seed=2024,
min_segment_len=min_segment_len,
)
self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
else:
self.num_segments_per_seq = 1
10 changes: 5 additions & 5 deletions tests/jax/utils.py
@@ -919,14 +919,14 @@ def apply_swa_mask(
"""Apply the sliding window mask to a given mask"""
_attn_mask_type = canonicalize_attn_mask_type(attn_mask_type)
assert _attn_mask_type is not None
batch = original_mask.shape[0]
max_seqlen_q = original_mask.shape[-2]
max_seqlen_kv = original_mask.shape[-1]
swa_mask = make_swa_mask(
max_seqlen_q, max_seqlen_kv, window_size, _attn_mask_type, dtype=original_mask.dtype
)
pos_q = jnp.broadcast_to(jnp.arange(max_seqlen_q), (batch, max_seqlen_q))
pos_kv = jnp.broadcast_to(jnp.arange(max_seqlen_kv), (batch, max_seqlen_kv))
swa_mask = make_swa_mask(pos_q, pos_kv, window_size, original_mask.dtype)
# In swa_mask and original_mask 0 is masked out
swa_mask_bcast = jnp.broadcast_to(swa_mask, original_mask.shape)
new_mask = jnp.where(original_mask == 1, swa_mask_bcast, original_mask)
new_mask = jnp.where(original_mask == 1, swa_mask, original_mask)
return new_mask
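
As a rough illustration (the full apply_swa_mask signature is not shown in this hunk), the core of the composition above is: build position indices per batch row, get the inverse SWA mask from make_swa_mask, and keep the original mask wherever it already masks out (0):

import jax.numpy as jnp
from transformer_engine.jax.attention import make_swa_mask

batch, seqlen = 2, 4
original_mask = jnp.ones((batch, 1, seqlen, seqlen))        # 1 = attend, 0 = masked out
pos = jnp.broadcast_to(jnp.arange(seqlen), (batch, seqlen))
swa_mask = make_swa_mask(pos, pos, window_size=(1, 1), dtype=original_mask.dtype)
new_mask = jnp.where(original_mask == 1, swa_mask, original_mask)  # shape (2, 1, 4, 4)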


79 changes: 32 additions & 47 deletions transformer_engine/jax/attention.py
@@ -147,59 +147,44 @@ class CPStrategy(Enum):


def make_swa_mask(
max_seqlen_q: int,
max_seqlen_kv: int,
segment_pos_q: jnp.ndarray,
segment_pos_kv: jnp.ndarray,
window_size: Optional[Tuple[int, int]] = None,
attn_mask_type: AttnMaskType = AttnMaskType.NO_MASK,
dtype: jax.typing.DTypeLike = jnp.float32,
):
"""
Generate sliding window mask. `True` or `1` means keep the element.

For `CAUSAL_BOTTOM_RIGHT_MASK` and `PADDING_CAUSAL_BOTTOM_RIGHT_MASK` mask type,
the sliding window diagonal is aligned to the bottom right corner, and for other
mask types, the top left corner.

Parameters
----------
max_seqlen_q: int
Maximum sequence length for queries.
max_seqlen_kv: int
Maximum sequence length for keys and values.
window_size: Optional[Tuple[int, int]] = None
Sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Negative number in window size means infinity window.
`None` means no sliding window.
attn_mask_type: AttnMaskType, default = AttnMaskType.NO_MASK
dtype: jax.typing.DTypeLike, default=jnp.float32
The mask data type.
Returns
----------
swa_mask: jax.numpy.tensor
Matrix with shape [max_seqlen_q, max_seqlen_kv]. Elements with value 1 are the positions
that will get attention, value 0 are the masked out positions.
Generate a sliding window mask (1 = attend, 0 = masked).

Args:
segment_pos_q (jnp.ndarray):
Query positions within each segment. For example, a batch with segment_ids =
[[1, 1, 1, 2, 2, 2, 2, 2]] yields segment_pos =
[[0, 1, 2, 0, 1, 2, 3, 4]].
segment_pos_kv (jnp.ndarray):
Key/value positions within each segment.
window_size (Optional[Tuple[int, int]], optional):
Sliding window size for local attention, where query at position i attends to keys
in [i - window_size[0], i + window_size[1]] inclusive. A negative number means an
infinite window; None means no sliding window.
Defaults to None.
dtype (jax.typing.DTypeLike, optional):
Mask data type. Defaults to jnp.float32.

Returns:
jnp.ndarray:
The mask with shape [b, 1, max_seqlen_q, max_seqlen_kv].
"""
swa_mask = jnp.ones((max_seqlen_q, max_seqlen_kv), dtype=dtype)
if window_size is None:
return swa_mask
left_window, right_window = window_size
if attn_mask_type.is_bottom_right():
if left_window < 0:
left_window = max_seqlen_kv
if right_window < 0:
right_window = max_seqlen_kv
bottom_right_shift = max_seqlen_kv - max_seqlen_q
swa_mask = jnp.triu(swa_mask, k=-left_window + bottom_right_shift)
swa_mask = jnp.tril(swa_mask, k=right_window + bottom_right_shift)
if window_size is not None:
left_window, right_window = window_size
else:
if left_window < 0:
left_window = max_seqlen_q
if right_window < 0:
right_window = max_seqlen_q
swa_mask = jnp.triu(swa_mask, k=-left_window)
swa_mask = jnp.tril(swa_mask, k=right_window)
return swa_mask
left_window = right_window = jnp.inf
left_window = jnp.inf if left_window < 0 else left_window
right_window = jnp.inf if right_window < 0 else right_window
pos_q = jnp.expand_dims(segment_pos_q, axis=-1)
pos_kv = jnp.expand_dims(segment_pos_kv, axis=-2)
inv_swa_mask = (pos_kv >= pos_q - left_window) & (pos_kv <= pos_q + right_window)
inv_swa_mask = jnp.expand_dims(inv_swa_mask, axis=-3)
return inv_swa_mask.astype(dtype)
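
A small worked example of the updated signature (a sketch, assuming a single row of length 4 with no segmentation):

import jax.numpy as jnp
from transformer_engine.jax.attention import make_swa_mask

pos = jnp.array([[0, 1, 2, 3]])
# window_size=(1, 0): each query attends to itself and one previous position.
swa = make_swa_mask(pos, pos, window_size=(1, 0))  # shape (1, 1, 4, 4)
# swa[0, 0] ==
# [[1., 0., 0., 0.],
#  [1., 1., 0., 0.],
#  [0., 1., 1., 0.],
#  [0., 0., 1., 1.]]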


def canonicalize_attn_mask_type(attn_mask_type: str):
17 changes: 10 additions & 7 deletions transformer_engine/jax/flax/transformer.py
@@ -194,15 +194,18 @@ def __call__(
if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
attn_weights += bias

def apply_swa_mask(attn_mask_type: AttnMaskType, original_mask: Array) -> Array:
def apply_swa_mask(original_mask: Array) -> Array:
"""Apply the sliding window mask to a given mask"""
batch = original_mask.shape[0]
max_seqlen_q = original_mask.shape[-2]
max_seqlen_kv = original_mask.shape[-1]
swa_mask = make_swa_mask(max_seqlen_q, max_seqlen_kv, self.window_size, attn_mask_type)
# In swa_mask 0 is masked out, in original_mask 1 is masked out
swa_mask = 1 - swa_mask.astype(original_mask.dtype)
swa_mask_bcast = jnp.broadcast_to(swa_mask, original_mask.shape)
new_mask = jnp.where(original_mask == 0, swa_mask_bcast, original_mask)
# TODO(rewang): Support THD format pos
pos_q = jnp.broadcast_to(jnp.arange(max_seqlen_q), (batch, max_seqlen_q))
pos_kv = jnp.broadcast_to(jnp.arange(max_seqlen_kv), (batch, max_seqlen_kv))
# In inv_swa_mask 0 is masked out, in original_mask 1 is masked out
inv_swa_mask = make_swa_mask(pos_q, pos_kv, self.window_size, original_mask.dtype)
swa_mask = 1 - inv_swa_mask
new_mask = jnp.where(original_mask == 0, swa_mask, original_mask)
return new_mask

def convert_to_softmax_type(attn_mask_type, mask):
@@ -213,7 +216,7 @@ def convert_to_softmax_type(attn_mask_type, mask):
if attn_mask_type == AttnMaskType.CAUSAL_MASK and self.window_size is None:
mask = None
if mask is not None:
mask = apply_swa_mask(attn_mask_type, mask)
mask = apply_swa_mask(mask)
# Currently the cuDNN backend only supports SWA for causal/padding_causal, so follow that here
if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]:
return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/module/layernorm_mlp.py
@@ -373,7 +373,7 @@ def forward(
ub=ub_obj_lnout if ub_overlap_ag else None,
extra_output_tensor=ln_out if ub_overlap_ag else None,
)
if not is_grad_enabled:
if not is_grad_enabled and not return_layernorm_output:
clear_tensor_data(ln_out_total)

if bias_gelu_nvfusion: