A faster flash attention bwd implementation #177

Open · wants to merge 4 commits into base: main
Changes from 1 commit
159 changes: 115 additions & 44 deletions jax_triton/pallas/ops/attention.py
@@ -99,7 +99,7 @@ def body(start_k, carry):
"interpret", "debug"])
def mha(q, k, v,
sm_scale: float = 1.0,
causal: bool = False,
causal: bool = True,
block_q: int = 128,
block_k: int = 128,
backward_pass_impl: str = "triton",
@@ -232,27 +232,77 @@ def _preprocess_backward(out, do, l, block_q: int,
name="mha_preprocess_backward")(out, do, l)
return do_scaled, delta

def mha_backward_kernel(
def mha_backward_kernel_dq(
# Inputs
q_ref, k_ref, v_ref, out_ref, do_scaled_ref,
l_ref, m_ref, delta_ref, _,
# Outputs
dq_ref, dk_ref, dv_ref,
dq_ref,
*, sm_scale: float, causal: bool,
block_q: int, block_d: int, block_k: int
):
del out_ref, l_ref # Not needed
seq_len = q_ref.shape[0]

def outer_loop(start_k, _):
start_q = pl.program_id(0)
q = pl.load(q_ref, (pl.ds(start_q * block_q, block_q), slice(None)))
span_q = start_q * block_q + jnp.arange(block_q)
m = pl.load(m_ref, (pl.ds(start_q * block_q, block_q),))
do = pl.load(do_scaled_ref, (pl.ds(start_q * block_q, block_q), slice(None)))
di = pl.load(delta_ref, (pl.ds(start_q * block_q, block_q),))
dq = pl.load(dq_ref, (pl.ds(start_q * block_q, block_q),
slice(None)), eviction_policy="evict_last")

def inner_loop(start_k, carry):
dq = carry
span_k = start_k * block_k + jnp.arange(block_k)
k = pl.load(k_ref, (pl.ds(start_k * block_k, block_k), slice(None)))
v = pl.load(v_ref, (pl.ds(start_k * block_k, block_k), slice(None)))

qk = pl.dot(q, k.T)
qk = qk.astype(q_ref.dtype)
qk = qk.astype(jnp.float32)
if sm_scale != 1.0:
qk *= sm_scale
if causal:
qk = jnp.where(span_q[:, None] >= span_k[None, :], qk, float('-inf'))
p = jnp.exp(qk - m[:, None])
dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None]
dp = dp + pl.dot(do, v.T)
ds = p * dp
if sm_scale != 1.0:
ds = ds * sm_scale
dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype)
return dq
if causal:
upper_bound = lax.div(start_q * block_q, block_k) + 1
else:
upper_bound = jt.cdiv(seq_len, block_k)
dq = lax.fori_loop(0, upper_bound, inner_loop, dq)
pl.store(dq_ref, (pl.ds(start_q * block_q, block_q),
slice(None)), dq, eviction_policy="evict_last")
Collaborator review comment on the lines above: I don't think we need eviction policy here
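If the hint is dropped as the comment suggests, the two accesses to the dq tile would presumably reduce to plain block load/store calls. A sketch of the suggested edit, not code from this PR:

dq = pl.load(dq_ref, (pl.ds(start_q * block_q, block_q), slice(None)))
# ... accumulate into dq inside the fori_loop over key blocks ...
pl.store(dq_ref, (pl.ds(start_q * block_q, block_q), slice(None)), dq)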



def mha_backward_kernel_dkv(
# Inputs
q_ref, k_ref, v_ref, out_ref, do_scaled_ref,
l_ref, m_ref, delta_ref,
# Outputs
dk_ref, dv_ref,
*, sm_scale: float, causal: bool,
block_q: int, block_d: int, block_k: int
):
del out_ref, l_ref # Not needed
seq_len = q_ref.shape[0]
start_k = pl.program_id(0)

dv = jnp.zeros([block_k, block_d], dtype=jnp.float32)
dk = jnp.zeros([block_k, block_d], dtype=jnp.float32)
k = pl.load(k_ref, (pl.ds(start_k * block_k, block_k), slice(None)))
v = pl.load(v_ref, (pl.ds(start_k * block_k, block_k), slice(None)))
span_k = start_k * block_k + jnp.arange(block_k)

def inner_loop(start_q, carry):
dv, dk = carry
q = pl.load(q_ref, (pl.ds(start_q * block_q, block_q), slice(None)))
qk = pl.dot(q, k.T)
@@ -274,23 +324,18 @@ def inner_loop(start_q, carry):
if sm_scale != 1.0:
ds = ds * sm_scale
dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q)
dq = pl.load(dq_ref, (pl.ds(start_q * block_q, block_q),
slice(None)), eviction_policy="evict_last")
dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype)
pl.store(dq_ref, (pl.ds(start_q * block_q, block_q),
slice(None)), dq, eviction_policy="evict_last")
return dv, dk
if causal:
lower_bound = lax.div(start_k * block_k, block_q)
else:
lower_bound = 0
dv, dk = lax.fori_loop(lower_bound, jt.cdiv(seq_len, block_q), inner_loop,
(dv, dk))
pl.store(dv_ref, (pl.ds(start_k * block_k, block_k),
slice(None)), dv.astype(dv_ref.dtype))
pl.store(dk_ref, (pl.ds(start_k * block_k, block_k),
slice(None)), dk.astype(dk_ref.dtype))
lax.fori_loop(0, jt.cdiv(seq_len, block_k), outer_loop, None)
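For reference, the two kernels above compute the same per-tile quantities as the fused kernel they replace. In the notation of the code, with p the re-materialized unnormalized softmax block (causal masking omitted here), do the pre-scaled output gradient, and di the per-row term loaded from delta_ref; the dv update itself falls in the collapsed part of the hunk:

$$p = \exp(\mathrm{sm\_scale}\cdot q k^\top - m), \qquad dp = do\, v^\top - di, \qquad ds = \mathrm{sm\_scale}\cdot (p \circ dp)$$
$$dq \leftarrow dq + ds\, k, \qquad dk \leftarrow dk + ds^\top q, \qquad dv \leftarrow dv + p^\top do$$

The dq kernel accumulates dq with one program per query block, and the dkv kernel accumulates dk and dv with one program per key block, so the backward pass parallelizes over the sequence dimension instead of looping over key blocks serially within each (batch, head) program.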


def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
backward_pass_impl: str, num_warps: Optional[int],
@@ -310,50 +355,76 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
elif backward_pass_impl == "triton":
# We accumulate into dq so we need to initialize it to zeros.
dq = jnp.zeros(q.shape, jnp.float32)
out_shapes = [
jax.ShapeDtypeStruct(dq.shape, dq.dtype),
jax.ShapeDtypeStruct(k.shape, k.dtype),
jax.ShapeDtypeStruct(v.shape, v.dtype),
]
out_shapes_q = jax.ShapeDtypeStruct(dq.shape, dq.dtype)

grid = (batch_size, num_heads)
grid_q = (jt.cdiv(seq_len, block_q), batch_size, num_heads)
# grid_q = (batch_size, num_heads)
# TODO(sharadmv): figure out why num_warps=8 doesn't work!
num_warps = 4
dq, dk, dv = pl.pallas_call(
functools.partial(mha_backward_kernel, block_q=block_q, block_d=head_dim,
dq = pl.pallas_call(
functools.partial(mha_backward_kernel_dq, block_q=block_q, block_d=head_dim,
block_k=block_k, sm_scale=sm_scale, causal=causal),
grid=grid,
out_shape=out_shapes,
grid=grid_q,
out_shape=out_shapes_q,
in_specs=[
pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
],
out_specs=[
pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
],
name="mha_backward",
debug=debug,
interpret=interpret,
num_warps=num_warps,
num_stages=1,
input_output_aliases={8: 0})(q, k, v, out, do_scaled, l, m, delta, dq)
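For context on input_output_aliases={8: 0} just above: input index 8 is the zero-initialized dq buffer, and aliasing it to output index 0 donates that buffer to the output, so the kernel can read and accumulate the dq tile in place. A minimal sketch of the same mechanism, assuming the usual jax_triton.pallas import alias and using interpret mode so it can run without a GPU:

import jax
import jax.numpy as jnp
from jax_triton import pallas as pl

def accumulate_kernel(x_ref, acc_ref, o_ref):
  # o_ref is backed by the same buffer as acc_ref because of the aliasing below.
  x = pl.load(x_ref, (pl.ds(0, 8),))
  acc = pl.load(acc_ref, (pl.ds(0, 8),))
  pl.store(o_ref, (pl.ds(0, 8),), acc + x)

x = jnp.ones((8,), jnp.float32)
acc = jnp.zeros((8,), jnp.float32)
out = pl.pallas_call(
    accumulate_kernel,
    out_shape=jax.ShapeDtypeStruct((8,), jnp.float32),
    input_output_aliases={1: 0},  # donate input 1 (acc) as output 0
    interpret=True)(x, acc)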

grid_kv = (jt.cdiv(seq_len, block_k), batch_size, num_heads)
out_shapes_kv = [
jax.ShapeDtypeStruct(k.shape, k.dtype),
jax.ShapeDtypeStruct(v.shape, v.dtype),
]
dk, dv = pl.pallas_call(
functools.partial(mha_backward_kernel_dkv, block_q=block_q, block_d=head_dim,
block_k=block_k, sm_scale=sm_scale, causal=causal),
grid=grid_kv,
out_shape=out_shapes_kv,
in_specs=[
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
],
out_specs=[
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
],
name="mha_backward",
debug=debug,
interpret=interpret,
num_warps=num_warps,
num_stages=1)(q, k, v, out, do_scaled, l, m, delta)

else:
raise ValueError(f"Invalid backward pass implementation: {backward_pass_impl}")
return dq.astype(q.dtype), dk, dv
mha.defvjp(_mha_forward, _mha_backward)


@functools.partial(jax.jit, static_argnames=['sm_scale', 'causal'])
def mha_reference(q, k, v, sm_scale=1.0, causal: bool = False):
def mha_reference(q, k, v, sm_scale=1.0, causal: bool = True):
q_seq_len = q.shape[1]
kv_seq_len = k.shape[1]
logits = jnp.einsum('bqhc,bkhc->bhqk', q, k).astype(jnp.float32)
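As a usage sketch, not part of the PR, one way to sanity-check the split backward kernels against the pure-JAX reference is to compare gradients of a scalar loss. The shapes, dtypes, and tolerances below are illustrative assumptions, and the Triton path needs a GPU:

import jax
import jax.numpy as jnp
from jax_triton.pallas.ops import attention

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
shape = (2, 512, 2, 64)  # (batch, seq_len, num_heads, head_dim)
q = jax.random.normal(k1, shape, dtype=jnp.float16)
k = jax.random.normal(k2, shape, dtype=jnp.float16)
v = jax.random.normal(k3, shape, dtype=jnp.float16)

def loss(q, k, v):
  # Both mha and mha_reference default to causal=True after this change.
  return attention.mha(q, k, v).sum()

def loss_ref(q, k, v):
  return attention.mha_reference(q, k, v).sum()

grads = jax.grad(loss, argnums=(0, 1, 2))(q, k, v)
grads_ref = jax.grad(loss_ref, argnums=(0, 1, 2))(q, k, v)
for g, g_ref in zip(grads, grads_ref):
  # Loose tolerances: fp16 inputs and different accumulation orders.
  assert jnp.allclose(g, g_ref, atol=5e-2, rtol=5e-2)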