From c65144b9d4673e8d6bebce21babd8ce0040eccc9 Mon Sep 17 00:00:00 2001 From: "Yuhuai(Tony) Wu" Date: Thu, 22 Jun 2023 13:04:44 -0700 Subject: [PATCH 1/4] A faster flash attention bwd implementation - Decompose the bwd kernel into two kernels, one for dq and one for dk,dv. - Extra parallelism over the sequence length axis. - On a benchmark, it is 4X faster compared to the previous implementation. 2X faster than XLA bwd pass. --- jax_triton/pallas/ops/attention.py | 159 +++++++++++++++++++++-------- 1 file changed, 115 insertions(+), 44 deletions(-) diff --git a/jax_triton/pallas/ops/attention.py b/jax_triton/pallas/ops/attention.py index e897ffba..9233c572 100644 --- a/jax_triton/pallas/ops/attention.py +++ b/jax_triton/pallas/ops/attention.py @@ -99,7 +99,7 @@ def body(start_k, carry): "interpret", "debug"]) def mha(q, k, v, sm_scale: float = 1.0, - causal: bool = False, + causal: bool = True, block_q: int = 128, block_k: int = 128, backward_pass_impl: str = "triton", @@ -232,27 +232,77 @@ def _preprocess_backward(out, do, l, block_q: int, name="mha_preprocess_backward")(out, do, l) return do_scaled, delta -def mha_backward_kernel( +def mha_backward_kernel_dq( # Inputs q_ref, k_ref, v_ref, out_ref, do_scaled_ref, l_ref, m_ref, delta_ref, _, # Outputs - dq_ref, dk_ref, dv_ref, + dq_ref, *, sm_scale: float, causal: bool, block_q: int, block_d: int, block_k: int ): del out_ref, l_ref # Not needed seq_len = q_ref.shape[0] - def outer_loop(start_k, _): + start_q = pl.program_id(0) + q = pl.load(q_ref, (pl.ds(start_q * block_q, block_q), slice(None))) + span_q = start_q * block_q + jnp.arange(block_q) + m = pl.load(m_ref, (pl.ds(start_q * block_q, block_q),)) + do = pl.load(do_scaled_ref, (pl.ds(start_q * block_q, block_q), slice(None))) + di = pl.load(delta_ref, (pl.ds(start_q * block_q, block_q),)) + dq = pl.load(dq_ref, (pl.ds(start_q * block_q, block_q), + slice(None)), eviction_policy="evict_last") + + def inner_loop(start_k, carry): + dq = carry + span_k = start_k * block_k + jnp.arange(block_k) + k = pl.load(k_ref, (pl.ds(start_k * block_k, block_k), slice(None))) + v = pl.load(v_ref, (pl.ds(start_k * block_k, block_k), slice(None))) + + qk = pl.dot(q, k.T) + qk = qk.astype(q_ref.dtype) + qk = qk.astype(jnp.float32) + if sm_scale != 1.0: + qk *= sm_scale + if causal: + qk = jnp.where(span_q[:, None] >= span_k[None, :], qk, float('-inf')) + p = jnp.exp(qk - m[:, None]) + dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None] + dp = dp + pl.dot(do, v.T) + ds = p * dp + if sm_scale != 1.0: + ds = ds * sm_scale + dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype) + return dq + if causal: + upper_bound = lax.div(start_q * block_q, block_k) + 1 + else: + upper_bound = jt.cdiv(seq_len, block_k) + dq = lax.fori_loop(0, upper_bound, inner_loop, dq) + pl.store(dq_ref, (pl.ds(start_q * block_q, block_q), + slice(None)), dq, eviction_policy="evict_last") + + +def mha_backward_kernel_dkv( + # Inputs + q_ref, k_ref, v_ref, out_ref, do_scaled_ref, + l_ref, m_ref, delta_ref, + # Outputs + dk_ref, dv_ref, + *, sm_scale: float, causal: bool, + block_q: int, block_d: int, block_k: int +): + del out_ref, l_ref # Not needed + seq_len = q_ref.shape[0] + start_k = pl.program_id(0) - dv = jnp.zeros([block_k, block_d], dtype=jnp.float32) - dk = jnp.zeros([block_k, block_d], dtype=jnp.float32) - k = pl.load(k_ref, (pl.ds(start_k * block_k, block_k), slice(None))) - v = pl.load(v_ref, (pl.ds(start_k * block_k, block_k), slice(None))) - span_k = start_k * block_k + jnp.arange(block_k) + dv = 
jnp.zeros([block_k, block_d], dtype=jnp.float32) + dk = jnp.zeros([block_k, block_d], dtype=jnp.float32) + k = pl.load(k_ref, (pl.ds(start_k * block_k, block_k), slice(None))) + v = pl.load(v_ref, (pl.ds(start_k * block_k, block_k), slice(None))) + span_k = start_k * block_k + jnp.arange(block_k) - def inner_loop(start_q, carry): + def inner_loop(start_q, carry): dv, dk = carry q = pl.load(q_ref, (pl.ds(start_q * block_q, block_q), slice(None))) qk = pl.dot(q, k.T) @@ -274,23 +324,18 @@ def inner_loop(start_q, carry): if sm_scale != 1.0: ds = ds * sm_scale dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q) - dq = pl.load(dq_ref, (pl.ds(start_q * block_q, block_q), - slice(None)), eviction_policy="evict_last") - dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype) - pl.store(dq_ref, (pl.ds(start_q * block_q, block_q), - slice(None)), dq, eviction_policy="evict_last") return dv, dk - if causal: + if causal: lower_bound = lax.div(start_k * block_k, block_q) - else: + else: lower_bound = 0 - dv, dk = lax.fori_loop(lower_bound, jt.cdiv(seq_len, block_q), inner_loop, + dv, dk = lax.fori_loop(lower_bound, jt.cdiv(seq_len, block_q), inner_loop, (dv, dk)) - pl.store(dv_ref, (pl.ds(start_k * block_k, block_k), + pl.store(dv_ref, (pl.ds(start_k * block_k, block_k), slice(None)), dv.astype(dv_ref.dtype)) - pl.store(dk_ref, (pl.ds(start_k * block_k, block_k), + pl.store(dk_ref, (pl.ds(start_k * block_k, block_k), slice(None)), dk.astype(dk_ref.dtype)) - lax.fori_loop(0, jt.cdiv(seq_len, block_k), outer_loop, None) + def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, backward_pass_impl: str, num_warps: Optional[int], @@ -310,35 +355,30 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, elif backward_pass_impl == "triton": # We accumulate into dq so we need to initialize it to zeros. dq = jnp.zeros(q.shape, jnp.float32) - out_shapes = [ - jax.ShapeDtypeStruct(dq.shape, dq.dtype), - jax.ShapeDtypeStruct(k.shape, k.dtype), - jax.ShapeDtypeStruct(v.shape, v.dtype), - ] + out_shapes_q = jax.ShapeDtypeStruct(dq.shape, dq.dtype) - grid = (batch_size, num_heads) + grid_q = (jt.cdiv(seq_len, block_q), batch_size, num_heads) + # grid_q = (batch_size, num_heads) # TODO(sharadmv): figure out why num_warps=8 doesn't work! 
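
A quick standalone sanity check of the causal loop bounds used above — an editor's sketch, not part of the patch. The dq kernel only visits k-blocks up to the diagonal (upper_bound = start_q * block_q // block_k + 1) and the dk/dv kernel only visits q-blocks from the diagonal onward (lower_bound = start_k * block_k // block_q); the check below assumes block_q divides block_k, which holds for the 128/128 defaults.

import numpy as np

def check_causal_bounds(seq_len=512, block_q=128, block_k=128):
  # Dense causal mask: entry (q, k) is live iff q >= k.
  q_idx = np.arange(seq_len)[:, None]
  k_idx = np.arange(seq_len)[None, :]
  mask = q_idx >= k_idx
  n_q, n_k = seq_len // block_q, seq_len // block_k
  for start_q in range(n_q):
    rows = mask[start_q * block_q:(start_q + 1) * block_q]
    needed = {kk // block_k for kk in np.flatnonzero(rows.any(axis=0))}
    visited = set(range(start_q * block_q // block_k + 1))   # dq kernel's upper_bound
    assert needed <= visited
  for start_k in range(n_k):
    cols = mask[:, start_k * block_k:(start_k + 1) * block_k]
    needed = {qq // block_q for qq in np.flatnonzero(cols.any(axis=1))}
    visited = set(range(start_k * block_k // block_q, n_q))  # dk/dv kernel's lower_bound
    assert needed <= visited
  return True

assert check_causal_bounds()                     # defaults: block_q == block_k == 128
assert check_causal_bounds(block_q=64, block_k=128)
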
num_warps = 4 - dq, dk, dv = pl.pallas_call( - functools.partial(mha_backward_kernel, block_q=block_q, block_d=head_dim, + dq = pl.pallas_call( + functools.partial(mha_backward_kernel_dq, block_q=block_q, block_d=head_dim, block_k=block_k, sm_scale=sm_scale, causal=causal), - grid=grid, - out_shape=out_shapes, + grid=grid_q, + out_shape=out_shapes_q, in_specs=[ - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), - pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), - pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), ], out_specs=[ - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), ], name="mha_backward", debug=debug, @@ -346,6 +386,37 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, num_warps=num_warps, num_stages=1, input_output_aliases={8: 0})(q, k, v, out, do_scaled, l, m, delta, dq) + + grid_kv = (jt.cdiv(seq_len, block_k), batch_size, num_heads) + out_shapes_kv = [ + jax.ShapeDtypeStruct(k.shape, k.dtype), + jax.ShapeDtypeStruct(v.shape, v.dtype), + ] + dk, dv = pl.pallas_call( + functools.partial(mha_backward_kernel_dkv, block_q=block_q, block_d=head_dim, + block_k=block_k, sm_scale=sm_scale, causal=causal), + grid=grid_kv, + out_shape=out_shapes_kv, + in_specs=[ + pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), + ], + out_specs=[ + pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + ], + name="mha_backward", + debug=debug, + interpret=interpret, + num_warps=num_warps, + num_stages=1)(q, k, v, out, do_scaled, l, m, delta) + else: 
raise ValueError(f"Invalid backward pass implementation: {backward_pass_impl}") return dq.astype(q.dtype), dk, dv @@ -353,7 +424,7 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, @functools.partial(jax.jit, static_argnames=['sm_scale', 'causal']) -def mha_reference(q, k, v, sm_scale=1.0, causal: bool = False): +def mha_reference(q, k, v, sm_scale=1.0, causal: bool = True): q_seq_len = q.shape[1] kv_seq_len = k.shape[1] logits = jnp.einsum('bqhc,bkhc->bhqk', q, k).astype(jnp.float32) From 7121d59a880ea011121b7e1fb362140b203c3c89 Mon Sep 17 00:00:00 2001 From: "Yuhuai(Tony) Wu" Date: Thu, 22 Jun 2023 21:25:00 +0000 Subject: [PATCH 2/4] add back previous attention kernel as an option --- jax_triton/pallas/ops/attention.py | 103 ++++++++++++++++++++++++++++- 1 file changed, 100 insertions(+), 3 deletions(-) diff --git a/jax_triton/pallas/ops/attention.py b/jax_triton/pallas/ops/attention.py index 9233c572..4929e81c 100644 --- a/jax_triton/pallas/ops/attention.py +++ b/jax_triton/pallas/ops/attention.py @@ -232,6 +232,66 @@ def _preprocess_backward(out, do, l, block_q: int, name="mha_preprocess_backward")(out, do, l) return do_scaled, delta +def mha_backward_kernel( + # Inputs + q_ref, k_ref, v_ref, out_ref, do_scaled_ref, + l_ref, m_ref, delta_ref, _, + # Outputs + dq_ref, dk_ref, dv_ref, + *, sm_scale: float, causal: bool, + block_q: int, block_d: int, block_k: int +): + del out_ref, l_ref # Not needed + seq_len = q_ref.shape[0] + + def outer_loop(start_k, _): + + dv = jnp.zeros([block_k, block_d], dtype=jnp.float32) + dk = jnp.zeros([block_k, block_d], dtype=jnp.float32) + k = pl.load(k_ref, (pl.ds(start_k * block_k, block_k), slice(None))) + v = pl.load(v_ref, (pl.ds(start_k * block_k, block_k), slice(None))) + span_k = start_k * block_k + jnp.arange(block_k) + + def inner_loop(start_q, carry): + dv, dk = carry + q = pl.load(q_ref, (pl.ds(start_q * block_q, block_q), slice(None))) + qk = pl.dot(q, k.T) + qk = qk.astype(q_ref.dtype) + qk = qk.astype(jnp.float32) + if sm_scale != 1.0: + qk *= sm_scale + if causal: + span_q = start_q * block_q + jnp.arange(block_q) + qk = jnp.where(span_q[:, None] >= span_k[None, :], qk, float('-inf')) + m = pl.load(m_ref, (pl.ds(start_q * block_q, block_q),)) + p = jnp.exp(qk - m[:, None]) + do = pl.load(do_scaled_ref, (pl.ds(start_q * block_q, block_q), slice(None))) + dv = dv + pl.dot(p.astype(do.dtype).T, do) + di = pl.load(delta_ref, (pl.ds(start_q * block_q, block_q),)) + dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None] + dp = dp + pl.dot(do, v.T) + ds = p * dp + if sm_scale != 1.0: + ds = ds * sm_scale + dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q) + dq = pl.load(dq_ref, (pl.ds(start_q * block_q, block_q), + slice(None)), eviction_policy="evict_last") + dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype) + pl.store(dq_ref, (pl.ds(start_q * block_q, block_q), + slice(None)), dq, eviction_policy="evict_last") + return dv, dk + if causal: + lower_bound = lax.div(start_k * block_k, block_q) + else: + lower_bound = 0 + dv, dk = lax.fori_loop(lower_bound, jt.cdiv(seq_len, block_q), inner_loop, + (dv, dk)) + pl.store(dv_ref, (pl.ds(start_k * block_k, block_k), + slice(None)), dv.astype(dv_ref.dtype)) + pl.store(dk_ref, (pl.ds(start_k * block_k, block_k), + slice(None)), dk.astype(dk_ref.dtype)) + lax.fori_loop(0, jt.cdiv(seq_len, block_k), outer_loop, None) + def mha_backward_kernel_dq( # Inputs q_ref, k_ref, v_ref, out_ref, do_scaled_ref, @@ -282,7 +342,6 @@ def inner_loop(start_k, carry): 
pl.store(dq_ref, (pl.ds(start_q * block_q, block_q), slice(None)), dq, eviction_policy="evict_last") - def mha_backward_kernel_dkv( # Inputs q_ref, k_ref, v_ref, out_ref, do_scaled_ref, @@ -336,7 +395,6 @@ def inner_loop(start_q, carry): pl.store(dk_ref, (pl.ds(start_k * block_k, block_k), slice(None)), dk.astype(dk_ref.dtype)) - def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, backward_pass_impl: str, num_warps: Optional[int], num_stages: int, grid: Any, interpret: bool, @@ -353,6 +411,45 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, return jax.vjp(functools.partial(mha_reference, sm_scale=sm_scale, causal=causal), q, k, v)[1](do) elif backward_pass_impl == "triton": + # We accumulate into dq so we need to initialize it to zeros. + dq = jnp.zeros(q.shape, jnp.float32) + out_shapes = [ + jax.ShapeDtypeStruct(dq.shape, dq.dtype), + jax.ShapeDtypeStruct(k.shape, k.dtype), + jax.ShapeDtypeStruct(v.shape, v.dtype), + ] + + grid = (batch_size, num_heads) + # TODO(sharadmv): figure out why num_warps=8 doesn't work! + num_warps = 4 + dq, dk, dv = pl.pallas_call( + functools.partial(mha_backward_kernel, block_q=block_q, block_d=head_dim, + block_k=block_k, sm_scale=sm_scale, causal=causal), + grid=grid, + out_shape=out_shapes, + in_specs=[ + pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + ], + out_specs=[ + pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + ], + name="mha_backward", + debug=debug, + interpret=interpret, + num_warps=num_warps, + num_stages=1, + input_output_aliases={8: 0})(q, k, v, out, do_scaled, l, m, delta, dq) + elif backward_pass_impl == "triton_split": # We accumulate into dq so we need to initialize it to zeros. 
dq = jnp.zeros(q.shape, jnp.float32) out_shapes_q = jax.ShapeDtypeStruct(dq.shape, dq.dtype) @@ -433,4 +530,4 @@ def mha_reference(q, k, v, sm_scale=1.0, causal: bool = True): mask = jnp.broadcast_to(mask, logits.shape) logits = jnp.where(mask, logits, float('-inf')) weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype) - return jnp.einsum('bhqk,bkhc->bqhc', weights, v) + return jnp.einsum('bhqk,bkhc->bqhc', weights, v) \ No newline at end of file From 8156ad0749287ac86e7782583011d0dd26c0b32a Mon Sep 17 00:00:00 2001 From: "Yuhuai(Tony) Wu" Date: Fri, 23 Jun 2023 00:00:33 +0000 Subject: [PATCH 3/4] fix comments by sharad --- jax_triton/pallas/ops/attention.py | 184 ++++++++++++++--------------- 1 file changed, 90 insertions(+), 94 deletions(-) diff --git a/jax_triton/pallas/ops/attention.py b/jax_triton/pallas/ops/attention.py index 4929e81c..8b8217a5 100644 --- a/jax_triton/pallas/ops/attention.py +++ b/jax_triton/pallas/ops/attention.py @@ -99,7 +99,7 @@ def body(start_k, carry): "interpret", "debug"]) def mha(q, k, v, sm_scale: float = 1.0, - causal: bool = True, + causal: bool = False, block_q: int = 128, block_k: int = 128, backward_pass_impl: str = "triton", @@ -190,7 +190,7 @@ def _mha_forward(q, k, v, sm_scale: float, causal: bool, block_q: int, def _preprocess_backward_kernel(out_ref, dout_ref, l_ref, new_dout_ref, delta_ref, *, block_q: int): - pid_m = pl.program_id(0) + pid_m = pl.program_id(2) off_m = pl.ds(pid_m * block_q, block_q) # load @@ -214,15 +214,15 @@ def _preprocess_backward(out, do, l, block_q: int, ] do_scaled, delta = pl.pallas_call( functools.partial(_preprocess_backward_kernel, block_q=block_q), - grid=(jt.cdiv(seq_len, block_q), batch_size, num_heads), + grid=(batch_size, num_heads, jt.cdiv(seq_len, block_q)), in_specs=[ - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k, _: (j, k, 0), (None, None, seq_len)), ], out_specs=[ - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k, _: (j, k, 0), (None, None, seq_len)), ], num_warps=4, num_stages=3, @@ -232,6 +232,7 @@ def _preprocess_backward(out, do, l, block_q: int, name="mha_preprocess_backward")(out, do, l) return do_scaled, delta + def mha_backward_kernel( # Inputs q_ref, k_ref, v_ref, out_ref, do_scaled_ref, @@ -292,10 +293,11 @@ def inner_loop(start_q, carry): slice(None)), dk.astype(dk_ref.dtype)) lax.fori_loop(0, jt.cdiv(seq_len, block_k), outer_loop, None) + def mha_backward_kernel_dq( # Inputs q_ref, k_ref, v_ref, out_ref, do_scaled_ref, - l_ref, m_ref, delta_ref, _, + l_ref, m_ref, delta_ref, # Outputs dq_ref, *, sm_scale: float, causal: bool, @@ -303,45 +305,43 @@ def mha_backward_kernel_dq( ): del out_ref, l_ref # Not needed seq_len = q_ref.shape[0] - - start_q = pl.program_id(0) + start_q = pl.program_id(2) q = pl.load(q_ref, (pl.ds(start_q * block_q, block_q), slice(None))) span_q = start_q * block_q + jnp.arange(block_q) m = pl.load(m_ref, (pl.ds(start_q * block_q, block_q),)) do = pl.load(do_scaled_ref, (pl.ds(start_q * block_q, block_q), 
slice(None))) di = pl.load(delta_ref, (pl.ds(start_q * block_q, block_q),)) - dq = pl.load(dq_ref, (pl.ds(start_q * block_q, block_q), - slice(None)), eviction_policy="evict_last") + dq = jnp.zeros([block_q, block_d], dtype=jnp.float32) - def inner_loop(start_k, carry): - dq = carry - span_k = start_k * block_k + jnp.arange(block_k) - k = pl.load(k_ref, (pl.ds(start_k * block_k, block_k), slice(None))) - v = pl.load(v_ref, (pl.ds(start_k * block_k, block_k), slice(None))) + def inner_loop(start_k, dq): + span_k = start_k * block_k + jnp.arange(block_k) + k = pl.load(k_ref, (pl.ds(start_k * block_k, block_k), slice(None))) + v = pl.load(v_ref, (pl.ds(start_k * block_k, block_k), slice(None))) - qk = pl.dot(q, k.T) - qk = qk.astype(q_ref.dtype) - qk = qk.astype(jnp.float32) - if sm_scale != 1.0: - qk *= sm_scale - if causal: - qk = jnp.where(span_q[:, None] >= span_k[None, :], qk, float('-inf')) - p = jnp.exp(qk - m[:, None]) - dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None] - dp = dp + pl.dot(do, v.T) - ds = p * dp - if sm_scale != 1.0: - ds = ds * sm_scale - dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype) - return dq + qk = pl.dot(q, k.T) + qk = qk.astype(q_ref.dtype) + qk = qk.astype(jnp.float32) + if sm_scale != 1.0: + qk *= sm_scale + if causal: + qk = jnp.where(span_q[:, None] >= span_k[None, :], qk, float('-inf')) + p = jnp.exp(qk - m[:, None]) + dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None] + dp = dp + pl.dot(do, v.T) + ds = p * dp + if sm_scale != 1.0: + ds = ds * sm_scale + dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype) + return dq if causal: - upper_bound = lax.div(start_q * block_q, block_k) + 1 + upper_bound = lax.div(start_q * block_q, block_k) + 1 else: - upper_bound = jt.cdiv(seq_len, block_k) + upper_bound = jt.cdiv(seq_len, block_k) dq = lax.fori_loop(0, upper_bound, inner_loop, dq) pl.store(dq_ref, (pl.ds(start_q * block_q, block_q), slice(None)), dq, eviction_policy="evict_last") + def mha_backward_kernel_dkv( # Inputs q_ref, k_ref, v_ref, out_ref, do_scaled_ref, @@ -353,7 +353,7 @@ def mha_backward_kernel_dkv( ): del out_ref, l_ref # Not needed seq_len = q_ref.shape[0] - start_k = pl.program_id(0) + start_k = pl.program_id(2) dv = jnp.zeros([block_k, block_d], dtype=jnp.float32) dk = jnp.zeros([block_k, block_d], dtype=jnp.float32) @@ -362,32 +362,32 @@ def mha_backward_kernel_dkv( span_k = start_k * block_k + jnp.arange(block_k) def inner_loop(start_q, carry): - dv, dk = carry - q = pl.load(q_ref, (pl.ds(start_q * block_q, block_q), slice(None))) - qk = pl.dot(q, k.T) - qk = qk.astype(q_ref.dtype) - qk = qk.astype(jnp.float32) - if sm_scale != 1.0: - qk *= sm_scale - if causal: - span_q = start_q * block_q + jnp.arange(block_q) - qk = jnp.where(span_q[:, None] >= span_k[None, :], qk, float('-inf')) - m = pl.load(m_ref, (pl.ds(start_q * block_q, block_q),)) - p = jnp.exp(qk - m[:, None]) - do = pl.load(do_scaled_ref, (pl.ds(start_q * block_q, block_q), slice(None))) - dv = dv + pl.dot(p.astype(do.dtype).T, do) - di = pl.load(delta_ref, (pl.ds(start_q * block_q, block_q),)) - dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None] - dp = dp + pl.dot(do, v.T) - ds = p * dp - if sm_scale != 1.0: - ds = ds * sm_scale - dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q) - return dv, dk + dv, dk = carry + q = pl.load(q_ref, (pl.ds(start_q * block_q, block_q), slice(None))) + qk = pl.dot(q, k.T) + qk = qk.astype(q_ref.dtype) + qk = qk.astype(jnp.float32) + if sm_scale != 1.0: + qk *= sm_scale + if causal: + 
span_q = start_q * block_q + jnp.arange(block_q) + qk = jnp.where(span_q[:, None] >= span_k[None, :], qk, float('-inf')) + m = pl.load(m_ref, (pl.ds(start_q * block_q, block_q),)) + p = jnp.exp(qk - m[:, None]) + do = pl.load(do_scaled_ref, (pl.ds(start_q * block_q, block_q), slice(None))) + dv = dv + pl.dot(p.astype(do.dtype).T, do) + di = pl.load(delta_ref, (pl.ds(start_q * block_q, block_q),)) + dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None] + dp = dp + pl.dot(do, v.T) + ds = p * dp + if sm_scale != 1.0: + ds = ds * sm_scale + dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q) + return dv, dk if causal: - lower_bound = lax.div(start_k * block_k, block_q) + lower_bound = lax.div(start_k * block_k, block_q) else: - lower_bound = 0 + lower_bound = 0 dv, dk = lax.fori_loop(lower_bound, jt.cdiv(seq_len, block_q), inner_loop, (dv, dk)) pl.store(dv_ref, (pl.ds(start_k * block_k, block_k), @@ -395,6 +395,7 @@ def inner_loop(start_q, carry): pl.store(dk_ref, (pl.ds(start_k * block_k, block_k), slice(None)), dk.astype(dk_ref.dtype)) + def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, backward_pass_impl: str, num_warps: Optional[int], num_stages: int, grid: Any, interpret: bool, @@ -451,40 +452,36 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, input_output_aliases={8: 0})(q, k, v, out, do_scaled, l, m, delta, dq) elif backward_pass_impl == "triton_split": # We accumulate into dq so we need to initialize it to zeros. - dq = jnp.zeros(q.shape, jnp.float32) - out_shapes_q = jax.ShapeDtypeStruct(dq.shape, dq.dtype) + out_shapes_q = jax.ShapeDtypeStruct(q.shape, jnp.float32) - grid_q = (jt.cdiv(seq_len, block_q), batch_size, num_heads) - # grid_q = (batch_size, num_heads) + grid_q = (batch_size, num_heads, jt.cdiv(seq_len, block_q)) # TODO(sharadmv): figure out why num_warps=8 doesn't work! 
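
For context, a hedged usage sketch (not part of the patch) of how one might exercise both backward implementations and cross-check them against the reference; the module path, shapes, and dtype below are illustrative assumptions.

import functools
import jax
import jax.numpy as jnp
from jax_triton.pallas.ops import attention   # module path as in this patch

batch, seq_len, heads, head_dim = 2, 1024, 4, 64
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (batch, seq_len, heads, head_dim), jnp.float16)
k = jax.random.normal(kk, (batch, seq_len, heads, head_dim), jnp.float16)
v = jax.random.normal(kv, (batch, seq_len, heads, head_dim), jnp.float16)

def make_loss(impl):
  f = functools.partial(attention.mha, causal=True, backward_pass_impl=impl)
  return lambda q, k, v: f(q, k, v).sum()

# The fused single-kernel backward ("triton") and the split dq / dk,dv backward
# ("triton_split") should produce matching gradients, up to numerical tolerance.
dq1, dk1, dv1 = jax.grad(make_loss("triton"), argnums=(0, 1, 2))(q, k, v)
dq2, dk2, dv2 = jax.grad(make_loss("triton_split"), argnums=(0, 1, 2))(q, k, v)

# Cross-check against the XLA reference path.
ref_loss = lambda q, k, v: attention.mha_reference(q, k, v, causal=True).sum()
dq_ref, dk_ref, dv_ref = jax.grad(ref_loss, argnums=(0, 1, 2))(q, k, v)
print(jnp.max(jnp.abs(dq2.astype(jnp.float32) - dq_ref.astype(jnp.float32))))
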
- num_warps = 4 + num_warps = 8 dq = pl.pallas_call( functools.partial(mha_backward_kernel_dq, block_q=block_q, block_d=head_dim, block_k=block_k, sm_scale=sm_scale, causal=causal), grid=grid_q, out_shape=out_shapes_q, in_specs=[ - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), - pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), - pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k, _: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec(lambda j, k, _: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec(lambda j, k, _: (j, k, 0), (None, None, seq_len)), ], out_specs=[ - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)), ], - name="mha_backward", + name="mha_backward_q", debug=debug, interpret=interpret, num_warps=num_warps, - num_stages=1, - input_output_aliases={8: 0})(q, k, v, out, do_scaled, l, m, delta, dq) + num_stages=2)(q, k, v, out, do_scaled, l, m, delta) - grid_kv = (jt.cdiv(seq_len, block_k), batch_size, num_heads) + grid_kv = (batch_size, num_heads, jt.cdiv(seq_len, block_k)) out_shapes_kv = [ jax.ShapeDtypeStruct(k.shape, k.dtype), jax.ShapeDtypeStruct(v.shape, v.dtype), @@ -495,25 +492,24 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, grid=grid_kv, out_shape=out_shapes_kv, in_specs=[ - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), - pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), - pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k, _: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec(lambda j, k, _: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec(lambda j, k, _: (j, k, 0), (None, None, seq_len)), ], out_specs=[ - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, 
None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)), ], - name="mha_backward", + name="mha_backward_kv", debug=debug, interpret=interpret, num_warps=num_warps, - num_stages=1)(q, k, v, out, do_scaled, l, m, delta) - + num_stages=2)(q, k, v, out, do_scaled, l, m, delta) else: raise ValueError(f"Invalid backward pass implementation: {backward_pass_impl}") return dq.astype(q.dtype), dk, dv @@ -521,7 +517,7 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, @functools.partial(jax.jit, static_argnames=['sm_scale', 'causal']) -def mha_reference(q, k, v, sm_scale=1.0, causal: bool = True): +def mha_reference(q, k, v, sm_scale=1.0, causal: bool = False): q_seq_len = q.shape[1] kv_seq_len = k.shape[1] logits = jnp.einsum('bqhc,bkhc->bhqk', q, k).astype(jnp.float32) @@ -530,4 +526,4 @@ def mha_reference(q, k, v, sm_scale=1.0, causal: bool = True): mask = jnp.broadcast_to(mask, logits.shape) logits = jnp.where(mask, logits, float('-inf')) weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype) - return jnp.einsum('bhqk,bkhc->bqhc', weights, v) \ No newline at end of file + return jnp.einsum('bhqk,bkhc->bqhc', weights, v) From bdbddc916e5e01afed6846a82035c069efb95962 Mon Sep 17 00:00:00 2001 From: "Yuhuai(Tony) Wu" Date: Fri, 23 Jun 2023 00:02:06 +0000 Subject: [PATCH 4/4] delete a comment --- jax_triton/pallas/ops/attention.py | 1 - 1 file changed, 1 deletion(-) diff --git a/jax_triton/pallas/ops/attention.py b/jax_triton/pallas/ops/attention.py index 8b8217a5..0b7df6c6 100644 --- a/jax_triton/pallas/ops/attention.py +++ b/jax_triton/pallas/ops/attention.py @@ -455,7 +455,6 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, out_shapes_q = jax.ShapeDtypeStruct(q.shape, jnp.float32) grid_q = (batch_size, num_heads, jt.cdiv(seq_len, block_q)) - # TODO(sharadmv): figure out why num_warps=8 doesn't work! num_warps = 8 dq = pl.pallas_call( functools.partial(mha_backward_kernel_dq, block_q=block_q, block_d=head_dim,