From 2d9aa3882f21a30b8966a508bb42bc25b4c94c60 Mon Sep 17 00:00:00 2001
From: Sharad Vikram
Date: Mon, 8 May 2023 18:16:45 +0000
Subject: [PATCH 1/3] Rebase onto latest

---
 jax_triton/pallas/triton_ir_lowering.py | 2 +-
 pyproject.toml                          | 2 --
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/jax_triton/pallas/triton_ir_lowering.py b/jax_triton/pallas/triton_ir_lowering.py
index 6e6afbee..981de4d6 100644
--- a/jax_triton/pallas/triton_ir_lowering.py
+++ b/jax_triton/pallas/triton_ir_lowering.py
@@ -137,7 +137,7 @@ def lower_jaxpr_to_triton_module(jaxpr: jax_core.Jaxpr, in_shapes, grid_spec: Gr
   assert len(jaxpr.outvars) == 0
   prototype = tl.function_type([], arg_types)
   out = prototype.to_ir(builder)
-  fn = builder.get_or_insert_function(module, name, out, "public")
+  fn = builder.get_or_insert_function(module, name, out, "public", False)
   module.push_back(fn)
   entry = fn.add_entry_block()
   args = []
diff --git a/pyproject.toml b/pyproject.toml
index 0b588943..22806e67 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,8 +6,6 @@ readme = "README.md"
 requires-python = ">=3.7,<3.11"
 dependencies = [
     "absl-py>=1.4.0",
-    "jax>=0.4.2",
-    "triton==2.0.0a2",
 ]
 
 [build-system]

From 02184d7ce3566156d4866c06f5f89e363036844e Mon Sep 17 00:00:00 2001
From: Sharad Vikram
Date: Sat, 11 Feb 2023 19:18:21 +0000
Subject: [PATCH 2/3] Initial AD impl

---
 jax_triton/pallas/pallas_call.py | 319 ++++++++++++++++++++++++++++++-
 tests/pallas_test.py             |  18 +-
 2 files changed, 325 insertions(+), 12 deletions(-)

diff --git a/jax_triton/pallas/pallas_call.py b/jax_triton/pallas/pallas_call.py
index c25c0975..ea846971 100644
--- a/jax_triton/pallas/pallas_call.py
+++ b/jax_triton/pallas/pallas_call.py
@@ -15,6 +15,7 @@
 """Module for calling pallas functions from JAX."""
 from functools import partial
 import itertools as it
+import operator as op
 from typing import Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple, Union
 
@@ -30,11 +31,12 @@
 from jax._src import ad_util
 from jax._src import core as jax_core
 from jax._src.lib.mlir.dialects import mhlo
+from jax._src import source_info_util
 from jax._src import state
 from jax._src.state import discharge as state_discharge
 from jax._src.util import (
     split_list, safe_map, safe_zip, weakref_lru_cache,
-    tuple_insert, partition_list)
+    tuple_insert, partition_list, merge_lists)
 from jax._src.lax.control_flow import for_loop
 import jax.numpy as jnp
 import numpy as np
@@ -141,6 +143,8 @@ def _pallas_call_abstract_eval(*avals, out_shapes, **_):
 def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
                           input_output_aliases: Tuple[Tuple[int, int], ...],
                           in_shapes, out_shapes, grid_spec, debug, interpret,
                           **compiler_params: Any):
+  num_inputs = len(in_shapes)
+  num_outputs = len(out_shapes)
   if input_output_aliases:
     raise NotImplementedError("JVP with aliasing not supported.")
   nonzero_tangents = [not isinstance(t, ad_util.Zero) for t in tangents]
@@ -151,7 +155,7 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
   closed_jaxpr = jax_core.ClosedJaxpr(jaxpr, ())
   jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, nonzero_tangents_with_outputs, [])
   jvp_jaxpr, () = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts  # TODO consts
-  jvp_which_linear = which_linear + (True,) * len(tangents)
+  jvp_which_linear = (*which_linear, *(True,) * len(tangents))
   jvp_inshapes = (*in_shapes, *in_shapes)
   jvp_outshapes = (*out_shapes, *out_shapes)
   if input_output_aliases:
@@ -190,6 +194,316 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
   return out_primals, out_tangents
 ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule
 
+_save_everything = lambda *_, **__: True
+
+def _convert_outputs_to_writes(
+    jaxpr: jax_core.Jaxpr,
+  ) -> Tuple[jax_core.Jaxpr, list[jax_core.ShapedArray]]:
+  assert not jaxpr.constvars, "Jaxpr shouldn't have constvars."
+
+  in_avals = [v.aval for v in jaxpr.invars]  # [*orig_ref_avals]
+  @lu.wrap_init
+  def eval_jaxpr(*refs):
+    # We split the refs into the original input refs and the dummy residual
+    # refs.
+    orig_refs, residual_refs = split_list(refs, [len(in_avals)])
+    residual_vals = jax_core.eval_jaxpr(jaxpr, (), *orig_refs)
+    for res_ref, res_val in zip(residual_refs, residual_vals):
+      res_ref[()] = res_val
+    return []
+  res_ref_avals = [state.ShapedArrayRef(v.aval.shape, v.aval.dtype)  # pytype: disable=attribute-error
+                   for v in jaxpr.outvars]
+  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
+      eval_jaxpr, [*in_avals, *res_ref_avals])
+  assert not consts
+  return jaxpr, [jax_core.ShapedArray(a.shape, a.dtype) for a in res_ref_avals]
+
+def _convert_inputs_to_reads(num_res: int, jaxpr: jax_core.Jaxpr
+                             ) -> jax_core.Jaxpr:
+  assert not jaxpr.constvars, "Jaxpr should not have constvars."
+
+  @lu.wrap_init
+  def eval_jaxpr(*refs):
+    residual_refs, orig_refs = split_list(refs, [num_res])
+    residual_vals = [r[()] for r in residual_refs]
+    () = jax_core.eval_jaxpr(jaxpr, (), *residual_vals, *orig_refs)
+    return []
+
+  res_val_avals, orig_ref_avals = split_list([v.aval for v in jaxpr.invars], [num_res])
+  res_ref_avals = [state.ShapedArrayRef(aval.shape, aval.dtype)
+                   for aval in res_val_avals]
+
+  jaxpr, _, () = pe.trace_to_jaxpr_dynamic(
+      eval_jaxpr, [*res_ref_avals, *orig_ref_avals])
+  return jaxpr
+
+def _pallas_call_partial_eval(
+    trace: pe.JaxprTrace,
+    *tracers: pe.JaxprTracer,
+    jaxpr: jax_core.Jaxpr,
+    name: str,
+    in_shapes: tuple[jax.ShapeDtypeStruct, ...],
+    out_shapes: tuple[jax.ShapeDtypeStruct, ...],
+    grid_spec: pallas_core.GridSpec,
+    which_linear: tuple[bool, ...],
+    interpret: bool,
+    debug: bool,
+    input_output_aliases: tuple[tuple[int, int], ...],
+    **compiler_params: Any):
+  if input_output_aliases:
+    raise NotImplementedError
+  num_inputs = len(in_shapes)
+  num_outputs = len(out_shapes)
+  assert num_inputs + num_outputs == len(jaxpr.invars)
+  in_unknowns = [not t.pval.is_known() for t in tracers]
+  out_unknowns = [False] * num_outputs
+  # We first need to run a fixpoint to determine which of the `Ref`s are
+  # unknown after running the kernel. We want to use the jaxpr to determine
+  # which `Ref`s are unknown after executing the kernel body given which
+  # `Ref`s are unknown before. However, the jaxpr has no outputs. Instead, we
+  # discharge the body and run the fixpoint with the discharged jaxpr. We can
+  # do this because the outputs of the discharged jaxpr are one-to-one with
+  # the inputs.
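+  # For intuition, consider a hypothetical two-ref body (an illustration, not
+  # part of this change): { x_ref, o_ref -> o_ref[()] = 2 * x_ref[()] }.
+  # Discharging it gives a functional jaxpr { x, o -> (x, 2 * x) } whose i-th
+  # output is the final value of the i-th ref, so seeding the fixpoint with
+  # `x` unknown marks output 1, and hence `o_ref`, unknown on the next pass.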
+  all_in_unknowns = [*in_unknowns, *out_unknowns]
+  discharged_jaxpr, discharged_consts = state.discharge_state(jaxpr, ())
+  discharged_jaxpr = discharged_jaxpr.replace(
+      invars=discharged_jaxpr.constvars + discharged_jaxpr.invars,
+      constvars=[])
+  for _ in range(num_inputs + num_outputs):
+    jaxpr_in_unknowns = [False] * len(discharged_consts) + all_in_unknowns
+    _, _, all_out_unknowns, _, _ = pe.partial_eval_jaxpr_custom(
+        discharged_jaxpr, jaxpr_in_unknowns, [True] * len(jaxpr_in_unknowns),
+        all_in_unknowns, False, _save_everything)
+    all_out_unknowns = list(all_out_unknowns)
+    if all_out_unknowns == all_in_unknowns:
+      break
+    all_in_unknowns = map(op.or_, all_in_unknowns, all_out_unknowns)
+  else:
+    raise Exception("Invalid fixpoint")
+  all_unknowns = all_in_unknowns
+  del all_in_unknowns, all_out_unknowns  # Both equal `all_unknowns` at the fixpoint.
+  in_unknowns, out_unknowns = split_list(all_unknowns, [num_inputs])
+
+  tracers = tuple(trace.instantiate_const(t) if uk else t  # type: ignore
+                  for t, uk in zip(tracers, in_unknowns))
+
+  # We use `partial_eval_jaxpr_custom` here because it won't remove effectful
+  # primitives like `get`/`set`.
+  jaxpr_known_resout, jaxpr_unknown_resin_, uk_out, inst_out, num_res = \
+      pe.partial_eval_jaxpr_custom(
+          jaxpr,
+          in_inst=all_unknowns,
+          in_unknowns=all_unknowns,
+          ensure_out_unknowns=[],
+          ensure_out_inst=[],
+          saveable=_save_everything)
+  # # `partial_eval_jaxpr_custom` will give us jaxprs that have hybrid `Ref` and
+  # regular valued input/outputs. However, we'd like to bind these jaxprs to a
+  # `for`, which expects only `Ref` inputs and no output. We need to convert
+  # both of these jaxprs into ones that are compatible with `for`.
+  # TODO(sharadmv,mattjj): implement "passthrough" optimization.
+  # TODO(sharadmv,mattjj): rematerialize loop-dependent values instead of
+  # passing the loop index as a residual
+
+  # `jaxpr_known_resout` is a jaxpr that maps from all the input `Refs`
+  # to output residual values (none of them should be `Ref`s). We'll need to
+  # convert the output residual values into `Ref`s that are initially empty
+  # `Ref`s that are written to at the end of the jaxpr.
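+  # As a hypothetical illustration (not from this change): a known jaxpr
+  # { x_ref -> sin(x_ref[()]) } becomes { x_ref, res_ref -> res_ref[()] =
+  # sin(x_ref[()]) } with one f32[] residual, and the unknown jaxpr then reads
+  # `res_ref` back instead of recomputing `sin`.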
+  jaxpr_known, res_avals = _convert_outputs_to_writes(jaxpr_known_resout)
+  jaxpr_unknown = _convert_inputs_to_reads(num_res, jaxpr_unknown_resin_)
+
+  # Now we execute the forward pass that returns known outputs and residuals
+  grid, block_mappings, mapped_dims = (
+      grid_spec.grid, grid_spec.block_mappings, grid_spec.mapped_dims)
+  in_block_mappings, out_block_mappings = split_list(block_mappings,
+                                                     [num_inputs])
+  known_in_block_mappings, unknown_in_block_mappings = partition_list(
+      in_unknowns, in_block_mappings)
+  known_out_block_mappings, unknown_out_block_mappings = partition_list(
+      out_unknowns, out_block_mappings)
+  known_in_shapes, unknown_in_shapes = partition_list(in_unknowns,
+                                                      in_shapes)
+  known_out_shapes, unknown_out_shapes = partition_list(out_unknowns,
+                                                        out_shapes)
+  known_which_linear, unknown_which_linear = partition_list(in_unknowns,
+                                                            which_linear)
+  res_which_linear = (False,) * num_res
+  known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)
+  known_vals = [t.pval.get_known() for t in known_tracers]
+  res_shapes = [jax.ShapeDtypeStruct((*grid, *a.shape), a.dtype)
+                for a in res_avals]
+  res_index_mappings = [
+      jax_core.ClosedJaxpr(
+          pe.trace_to_jaxpr_dynamic(
+              lu.wrap_init(lambda *args: (*args, *[0] * len(a.shape))),
+              [jax_core.ShapedArray((), jnp.int32)] * len(grid))[0], ())
+      for a in res_avals
+  ]
+  res_block_mappings = [
+      BlockMapping((*[None] * len(grid), *a.shape), index_map)
+      for a, index_map in zip(res_avals, res_index_mappings)
+  ]
+  known_grid_spec = GridSpec(grid, (*known_in_block_mappings,
+                                    *known_out_block_mappings,
+                                    *res_block_mappings),
+                             grid_spec.mapped_dims)
+  unknown_grid_spec = GridSpec(grid, (*res_block_mappings,
+                                      *unknown_in_block_mappings,
+                                      *unknown_out_block_mappings),
+                               grid_spec.mapped_dims)
+  known_out_and_res = pallas_call_p.bind(
+      *known_vals,
+      jaxpr=jaxpr_known,
+      grid_spec=known_grid_spec,
+      in_shapes=tuple(known_in_shapes),
+      out_shapes=(*known_out_shapes, *res_shapes),
+      interpret=interpret,
+      debug=debug,
+      name=f"{name}_known",
+      input_output_aliases=(),
+      which_linear=tuple(known_which_linear),
+      **compiler_params)
+  known_outputs, residuals = split_list(known_out_and_res, [len(known_tracers)])
+  residuals = map(trace.new_instantiated_const, residuals)
+  unknown_inputs = [*residuals, *unknown_tracers]
+  unknown_outputs = [
+      pe.JaxprTracer(trace, pe.PartialVal.unknown(jax_core.ShapedArray(s.shape,
+                                                                       s.dtype)), None)
+      for s in unknown_out_shapes]
+  name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
+  source = source_info_util.current().replace(name_stack=name_stack)
+  unknown_params = dict(
+      jaxpr=jaxpr_unknown,
+      in_shapes=(*(jax.ShapeDtypeStruct(s.shape, s.dtype) for s in res_avals),
+                 *unknown_in_shapes),
+      out_shapes=tuple(unknown_out_shapes),
+      grid_spec=unknown_grid_spec,
+      which_linear=(*res_which_linear, *unknown_which_linear),
+      debug=debug,
+      interpret=interpret,
+      name=f"{name}_unknown",
+      input_output_aliases=(),
+      **compiler_params)
+  eqn = pe.new_eqn_recipe(unknown_inputs, unknown_outputs,
+                          pallas_call_p, unknown_params,
+                          jax_core.no_effects, source)
+  for t in unknown_outputs: t.recipe = eqn
+  return merge_lists(out_unknowns, known_outputs, unknown_outputs)
+pe.custom_partial_eval_rules[pallas_call_p] = _pallas_call_partial_eval
+
+def _transpose_jaxpr(jaxpr: jax_core.Jaxpr, which_linear: Sequence[bool]
+                     ) -> jax_core.Jaxpr:
+  num_inputs = len(which_linear)
+  num_outputs = len(jaxpr.invars) - num_inputs
+  def trans(*args):
+    # First we want to run the computation to read all the residual refs. We can
+    # do that by using partial evaluation with all linear inputs unknown.
+    res_jaxpr, tangent_jaxpr_, *_ = \
+        pe.partial_eval_jaxpr_custom(jaxpr,
+                                     in_unknowns=[*which_linear, *[True] *
+                                                  num_outputs],
+                                     in_inst=[*which_linear, *[True] *
+                                              num_outputs],
+                                     ensure_out_inst=[],
+                                     ensure_out_unknowns=[],
+                                     saveable=_save_everything)
+    res_args = [x for x, lin in zip(args, which_linear) if not lin]
+    res = jax_core.eval_jaxpr(res_jaxpr, (), *res_args)
+
+    # Now that we have residual values, we run the tangent jaxpr. It takes as
+    # input the residuals, and all the refs (at least, the ones
+    # that are used in the body). Luckily, `tangent_jaxpr_` has all known and
+    # unknown inputs!
+    breakpoint()
+    primals_args = [*(r for u, r in zip(used_res, res) if u)]
+    ct_args = [x for x, u in zip(args, used_ct) if u]
+    ad.backward_pass(
+        tangent_jaxpr, (), False, (), (*res, *ct_args), ())
+    breakpoint()
+    return []
+  jaxpr_trans, _, _ = pe.trace_to_jaxpr_dynamic(
+      lu.wrap_init(trans), [v.aval for v in jaxpr.invars])
+  return jaxpr_trans
+
+def _pallas_call_transpose_rule(cts_in, *args,
+                                jaxpr: jax_core.Jaxpr,
+                                name: str,
+                                in_shapes: Tuple[jax.ShapeDtypeStruct, ...],
+                                out_shapes: Tuple[jax.ShapeDtypeStruct, ...],
+                                grid_spec: GridSpec,
+                                input_output_aliases: Tuple[Tuple[int, int], ...],
+                                debug: bool,
+                                interpret: bool,
+                                which_linear: Tuple[bool, ...],
+                                **compiler_params: Any):
+  num_inputs = len(in_shapes)
+  num_outputs = len(out_shapes)
+  is_undefined_primal = [ad.is_undefined_primal(x) for x in args]
+  defined_primals, undefined_primals = partition_list(is_undefined_primal, args)
+  defined_in_shapes, undefined_in_shapes = partition_list(is_undefined_primal,
+                                                          in_shapes)
+  block_mappings = grid_spec.block_mappings
+  in_block_mappings, out_block_mappings = split_list(block_mappings,
+                                                     [num_inputs])
+  defined_in_block_mappings, undefined_in_block_mappings = partition_list(
+      is_undefined_primal, in_block_mappings)
+  defined_which_linear, undefined_which_linear = partition_list(
+      is_undefined_primal, which_linear)
+  num_undefined_inputs = sum(is_undefined_primal)
+  num_defined_inputs = num_inputs - num_undefined_inputs
+  def trans(*args):
+    defined_primals, cts, undefined_primals = split_list(args,
+                                                         [num_defined_inputs,
+                                                          num_outputs])
+    # First we want to run the computation to read all the residual refs. We can
+    # do that by using partial evaluation with all linear inputs unknown.
+    res_jaxpr, tangent_jaxpr_, *_ = \
+        pe.partial_eval_jaxpr_custom(jaxpr,
+                                     in_unknowns=[*is_undefined_primal, *[True] *
+                                                  num_outputs],
+                                     in_inst=[*is_undefined_primal, *[True] *
+                                              num_outputs],
+                                     ensure_out_inst=[],
+                                     ensure_out_unknowns=[],
+                                     saveable=_save_everything)
+    res = jax_core.eval_jaxpr(res_jaxpr, (), *defined_primals)
+
+    # Now that we have residual values, we run the tangent jaxpr. It takes as
+    # input the residuals, and all the refs (at least, the ones
+    # that are used in the body). Luckily, `tangent_jaxpr_` has all known and
+    # unknown inputs!
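+    # The cotangent refs are passed to `backward_pass` as primal inputs, so
+    # cotangent accumulation happens through ref writes inside the jaxpr, and
+    # the trailing `()` reflects that the kernel jaxpr itself has no outputs
+    # to supply cotangents for.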
+    ad.backward_pass(
+        tangent_jaxpr_, (), False, (), (*res, *undefined_primals, *cts), ())
+    return []
+  jaxpr_avals = [v.aval for v in jaxpr.invars]
+  jaxpr_in_avals, jaxpr_out_avals = split_list(jaxpr_avals, [num_inputs])
+  jaxpr_defined_in_avals, jaxpr_undefined_in_avals = partition_list(
+      is_undefined_primal, jaxpr_in_avals)
+  jaxpr_trans, _, _ = pe.trace_to_jaxpr_dynamic(
+      lu.wrap_init(trans), [*jaxpr_defined_in_avals, *jaxpr_out_avals,
+                            *jaxpr_undefined_in_avals])
+  grid_spec = GridSpec(
+      grid_spec.grid, (*defined_in_block_mappings, *out_block_mappings,
+                       *undefined_in_block_mappings),
+      grid_spec.mapped_dims)
+  cts_out = pallas_call_p.bind(
+      *defined_primals, *cts_in,
+      jaxpr=jaxpr_trans,
+      grid_spec=grid_spec,
+      in_shapes=(*defined_in_shapes, *out_shapes),
+      out_shapes=tuple(undefined_in_shapes),
+      name=f"{name}_transpose",
+      debug=debug,
+      interpret=interpret,
+      which_linear=(*defined_which_linear, *[True] * num_outputs),
+      input_output_aliases=(),
+      **compiler_params)
+  cts_out_iter = iter(cts_out)
+  return [next(cts_out_iter) if ud else None for
+          ud in is_undefined_primal]
+ad.primitive_transposes[pallas_call_p] = _pallas_call_transpose_rule
+
 def _batch_block_mapping(grid: Tuple[int, ...], aval: jax_core.ShapedArray,
                          dim: Union[int, batching.NotMapped],
                          block_mapping: Optional[BlockMapping]) -> BlockMapping:
@@ -347,7 +661,6 @@ def pallas_call(f: Callable, out_shape: Any, *, debug: bool = False,
   out_specs = tuple(out_specs)
   flat_out_shapes = [jax.ShapeDtypeStruct(x.shape, x.dtype)
                      for x in flat_out_shapes]
-  @jax.jit
   def wrapped(*args):
     flat_args, in_tree = tree_util.tree_flatten(args)
     if grid is None:
diff --git a/tests/pallas_test.py b/tests/pallas_test.py
index 29964d0b..49738f8f 100644
--- a/tests/pallas_test.py
+++ b/tests/pallas_test.py
@@ -752,15 +752,15 @@ def pallas_impl(x_ref, o_ref):
                                rtol=1e-5)
     jtu.check_grads(pallas_impl, (x,), modes=["fwd"], order=2)
 
-  # TODO(sharadmv): enable this when we update Triton
-  # def test_jvp_matmul(self):
-  #   k1, k2 = random.split(random.PRNGKey(0))
-  #   x = random.normal(k1, (256, 128))
-  #   y = random.normal(k2, (128, 64))
-  #   bm, bn, bk, gm = 64, 128, 32, 8
-  #   mm = functools.partial(matmul, bm=bm, bn=bn, bk=bk, gm=gm,
-  #                          interpret=self.INTERPRET)
-  #   jtu.check_grads(mm, (x, y), modes=["fwd"], order=1)
+  # TODO(sharadmv): enable this when we update Triton
+  def test_jvp_matmul(self):
+    k1, k2 = random.split(random.PRNGKey(0))
+    x = random.normal(k1, (256, 128))
+    y = random.normal(k2, (128, 64))
+    bm, bn, bk, gm = 64, 128, 32, 8
+    mm = functools.partial(matmul, bm=bm, bn=bn, bk=bk, gm=gm,
+                           interpret=self.INTERPRET)
+    jtu.check_grads(mm, (x, y), modes=["fwd"], order=1)
 
   def test_slicing_block_spec(self):
     @functools.partial(

From 1b22199ed59edd30cd04d09afb278c3492d14c8b Mon Sep 17 00:00:00 2001
From: Sharad Vikram
Date: Mon, 13 Feb 2023 18:05:37 +0000
Subject: [PATCH 3/3] WIP AD

---
 jax_triton/pallas/pallas_call.py | 179 +++++++++++++++++++++++--------
 tests/pallas_test.py             |  25 ++++-
 2 files changed, 156 insertions(+), 48 deletions(-)

diff --git a/jax_triton/pallas/pallas_call.py b/jax_triton/pallas/pallas_call.py
index ea846971..d1d2d170 100644
--- a/jax_triton/pallas/pallas_call.py
+++ b/jax_triton/pallas/pallas_call.py
@@ -156,7 +156,8 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
   jvp_jaxpr_, _ = ad.jvp_jaxpr(closed_jaxpr, nonzero_tangents_with_outputs, [])
   jvp_jaxpr, () = jvp_jaxpr_.jaxpr, jvp_jaxpr_.consts  # TODO consts
   jvp_which_linear = (*which_linear, *(True,) * len(tangents))
-  jvp_inshapes = (*in_shapes, *in_shapes)
+  _, nonzero_tangent_in_shapes = partition_list(nonzero_tangents, in_shapes)
+  jvp_inshapes = (*in_shapes, *nonzero_tangent_in_shapes)
   jvp_outshapes = (*out_shapes, *out_shapes)
   if input_output_aliases:
@@ -172,7 +173,8 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
   logical_primal_inputs, logical_primal_outputs = split_list(logical_primals, [len(primals)])
   logical_tangent_inputs, logical_tangent_outputs = split_list(logical_tangents, [len(tangents)])
   in_bms, out_bms = split_list(grid_spec.block_mappings, [len(primals)])
-  new_bms = tuple((*in_bms, *in_bms, *out_bms, *out_bms))
+  _, nonzero_in_bms = partition_list(nonzero_tangents, in_bms)
+  new_bms = tuple((*in_bms, *nonzero_in_bms, *out_bms, *out_bms))
   new_grid_spec = grid_spec.replace(block_mappings=new_bms)
   jvp_jaxpr = jvp_jaxpr.replace(invars=[*logical_primal_inputs,
                                         *logical_tangent_inputs,
@@ -291,12 +293,13 @@ def _pallas_call_partial_eval(
   jaxpr_known_resout, jaxpr_unknown_resin_, uk_out, inst_out, num_res = \
       pe.partial_eval_jaxpr_custom(
          jaxpr,
-          in_inst=all_unknowns,
+          in_inst=True,
          in_unknowns=all_unknowns,
          ensure_out_unknowns=[],
          ensure_out_inst=[],
          saveable=_save_everything)
-  # # `partial_eval_jaxpr_custom` will give us jaxprs that have hybrid `Ref` and
+  # `partial_eval_jaxpr_custom` will give us jaxprs that have hybrid `Ref` and
   # regular valued input/outputs. However, we'd like to bind these jaxprs to a
   # `for`, which expects only `Ref` inputs and no output. We need to convert
   # both of these jaxprs into ones that are compatible with `for`.
@@ -339,13 +342,13 @@ def _pallas_call_partial_eval(
       for a in res_avals
   ]
   res_block_mappings = [
-      BlockMapping((*[None] * len(grid), *a.shape), index_map)
+      BlockMapping((*[pallas_core.mapped] * len(grid), *a.shape), index_map)
       for a, index_map in zip(res_avals, res_index_mappings)
   ]
   known_grid_spec = GridSpec(grid, (*known_in_block_mappings,
                                     *known_out_block_mappings,
                                     *res_block_mappings),
-                             grid_spec.mapped_dims)
+                             mapped_dims)
   unknown_grid_spec = GridSpec(grid, (*res_block_mappings,
                                       *unknown_in_block_mappings,
                                       *unknown_out_block_mappings),
                                grid_spec.mapped_dims)
@@ -362,7 +365,7 @@ def _pallas_call_partial_eval(
       input_output_aliases=(),
       which_linear=tuple(known_which_linear),
       **compiler_params)
-  known_outputs, residuals = split_list(known_out_and_res, [len(known_tracers)])
+  known_outputs, residuals = split_list(known_out_and_res, [len(known_out_shapes)])
   residuals = map(trace.new_instantiated_const, residuals)
   unknown_inputs = [*residuals, *unknown_tracers]
   unknown_outputs = [
@@ -373,8 +376,7 @@ def _pallas_call_partial_eval(
   source = source_info_util.current().replace(name_stack=name_stack)
   unknown_params = dict(
       jaxpr=jaxpr_unknown,
-      in_shapes=(*(jax.ShapeDtypeStruct(s.shape, s.dtype) for s in res_avals),
-                 *unknown_in_shapes),
+      in_shapes=(*res_shapes, *unknown_in_shapes),
       out_shapes=tuple(unknown_out_shapes),
       grid_spec=unknown_grid_spec,
       which_linear=(*res_which_linear, *unknown_which_linear),
@@ -390,40 +392,6 @@ def _pallas_call_partial_eval(
   return merge_lists(out_unknowns, known_outputs, unknown_outputs)
 pe.custom_partial_eval_rules[pallas_call_p] = _pallas_call_partial_eval
 
-def _transpose_jaxpr(jaxpr: jax_core.Jaxpr, which_linear: Sequence[bool]
-                     ) -> jax_core.Jaxpr:
-  num_inputs = len(which_linear)
-  num_outputs = len(jaxpr.invars) - num_inputs
-  def trans(*args):
-    # First we want to run the computation to read all the residual refs. We can
-    # do that by using partial evaluation with all linear inputs unknown.
-    res_jaxpr, tangent_jaxpr_, *_ = \
-        pe.partial_eval_jaxpr_custom(jaxpr,
-                                     in_unknowns=[*which_linear, *[True] *
-                                                  num_outputs],
-                                     in_inst=[*which_linear, *[True] *
-                                              num_outputs],
-                                     ensure_out_inst=[],
-                                     ensure_out_unknowns=[],
-                                     saveable=_save_everything)
-    res_args = [x for x, lin in zip(args, which_linear) if not lin]
-    res = jax_core.eval_jaxpr(res_jaxpr, (), *res_args)
-
-    # Now that we have residual values, we run the tangent jaxpr. It takes as
-    # input the residuals, and all the refs (at least, the ones
-    # that are used in the body). Luckily, `tangent_jaxpr_` has all known and
-    # unknown inputs!
-    breakpoint()
-    primals_args = [*(r for u, r in zip(used_res, res) if u)]
-    ct_args = [x for x, u in zip(args, used_ct) if u]
-    ad.backward_pass(
-        tangent_jaxpr, (), False, (), (*res, *ct_args), ())
-    breakpoint()
-    return []
-  jaxpr_trans, _, _ = pe.trace_to_jaxpr_dynamic(
-      lu.wrap_init(trans), [v.aval for v in jaxpr.invars])
-  return jaxpr_trans
-
 def _pallas_call_transpose_rule(cts_in, *args,
                                 jaxpr: jax_core.Jaxpr,
                                 name: str,
@@ -592,6 +560,105 @@ def _pallas_call_batching_rule(args, dims, *,
   return out, (0,) * len(out)
 batching.primitive_batchers[pallas_call_p] = _pallas_call_batching_rule
 
+class TritonCompilationResult(NamedTuple):
+  name: str
+  asm: Dict[str, str]
+  shared_mem: int
+  lowering_result: lowering.TritonLoweringResult
+
+@weakref_lru_cache
+def _compile_jaxpr(jaxpr: jax_core.Jaxpr, in_shapes, grid_spec: GridSpec,
+                   name: str, num_warps: int, num_stages: int
+                   ) -> TritonCompilationResult:
+  lowering_result = lowering.lower_jaxpr_to_triton_module(jaxpr, in_shapes, grid_spec, name)
+  backend = tc.runtime.backend.CUDA
+  device = 0
+  name, asm, shared_mem = tc.code_gen.compile_ttir(backend, lowering_result.module, device,
+                                                   num_warps, num_stages, {}, 0)
+  return TritonCompilationResult(name, asm, shared_mem, lowering_result)
+
+
+def pallas_call_lowering(ctx: mlir.LoweringRuleContext, *in_nodes,
+                         jaxpr: jax_core.Jaxpr,
+                         name: str,
+                         in_shapes: Tuple[jax.ShapeDtypeStruct, ...],
+                         out_shapes: Tuple[jax.ShapeDtypeStruct, ...],
+                         which_linear: Tuple[bool, ...],
+                         interpret: bool,
+                         debug: bool,
+                         input_output_aliases: Tuple[Tuple[int, int], ...],
+                         grid_spec: GridSpec,
+                         **compiler_params: Any):
+  if interpret:
+    return mlir.lower_fun(_pallas_call_impl, multiple_results=True)(
+        ctx, *in_nodes, jaxpr=jaxpr, name=name, out_shapes=out_shapes,
+        in_shapes=in_shapes,
+        which_linear=which_linear,
+        interpret=interpret, debug=debug,
+        input_output_aliases=input_output_aliases,
+        grid_spec=grid_spec, **compiler_params)
+  num_warps = compiler_params.get("num_warps", 4)
+  num_stages = compiler_params.get("num_stages", 3)
+  compilation_result = _compile_jaxpr(jaxpr, tuple((*in_shapes, *out_shapes)),
+                                      grid_spec, name, num_warps, num_stages)
+  name = compilation_result.name
+  asm = compilation_result.asm
+  shared_mem = compilation_result.shared_mem
+  ref_effects = state.get_ref_state_effects(
+      [v.aval for v in jaxpr.invars], jaxpr.effects)
+  is_accum = [
+      all(isinstance(eff, state.AccumEffect) for eff in ref_effect)
+      for ref_effect in ref_effects
+  ]
+  if debug:
+    print(jaxpr)
+    print(grid_spec)
+  lowering_result = compilation_result.lowering_result
+  if debug:
+    lowering_result.module.print()
+  out_type = ir.TupleType.get_tuple([
+      ir.RankedTensorType.get(out_shape.shape, mlir.dtype_to_ir_type(out_shape.dtype))
+      for out_shape in ctx.avals_out])
+  i32_type = ir.IntegerType.get_signless(32)
+
+  kernel = triton_kernel_call_lib.TritonKernel(
+      asm["cubin"], name, num_warps, shared_mem
+  )
+
+  grid = normalize_grid(compilation_result.lowering_result.grid, metaparams={})
+  # All arguments are buffers.
+  all_args = [None] * (len(in_shapes) + len(out_shapes))
+  kernel_call = triton_kernel_call_lib.TritonKernelCall(
+      kernel, grid[0], grid[1], grid[2], all_args,
+      is_accum,
+      [s.size for s in [*in_shapes, *out_shapes]]
+  )
+
+  ctx.module_context.add_keepalive(kernel_call)
+  output_operand_aliases = ir.ArrayAttr.get([
+      mhlo.OutputOperandAlias.get(
+          output_tuple_indices=[output],
+          operand_index=input,
+          operand_tuple_indices=[])
+      for input, output in input_output_aliases
+  ])
+  out = mhlo.CustomCallOp(
+      [out_type],
+      in_nodes,
+      call_target_name=ir.StringAttr.get("triton_kernel_call"),
+      has_side_effect=ir.BoolAttr.get(False),
+      backend_config=ir.StringAttr.get(kernel_call.descriptor),
+      api_version=ir.IntegerAttr.get(i32_type, 1),
+      called_computations=ir.ArrayAttr.get([]),
+      operand_layouts=avals_to_layouts(ctx.avals_in),
+      result_layouts=avals_to_layouts(ctx.avals_out),
+      output_operand_aliases=output_operand_aliases,
+  )
+  results = [mhlo.GetTupleElementOp(out, mlir.i32_attr(i)).result
+             for i in range(len(out_shapes))]
+  return results
+mlir.register_lowering(pallas_call_p, pallas_call_lowering, platform="cuda")
+
 @weakref_lru_cache
 def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
                               primitive_name: Optional[str] = None):
@@ -633,6 +700,32 @@ def _compute_shape_from_block_spec(block_spec: Optional[BlockSpec],
     return arg_shape
   return tuple(s for s in block_spec.block_shape if s is not None)
 
+def _pallas_call_bind(*args,
+                      jaxpr: jax_core.Jaxpr,
+                      name: str,
+                      in_shapes: Tuple[jax.ShapeDtypeStruct, ...],
+                      out_shapes: Tuple[jax.ShapeDtypeStruct, ...],
+                      which_linear: Tuple[bool, ...],
+                      interpret: bool,
+                      debug: bool,
+                      input_output_aliases: Tuple[Tuple[int, int], ...],
+                      grid_spec: GridSpec,
+                      **compiler_params: Any):
+  num_inputs = len(in_shapes)
+  num_outputs = len(out_shapes)
+  assert len(jaxpr.invars) == num_inputs + num_outputs, (len(jaxpr.invars),
+                                                         num_inputs,
+                                                         num_outputs)
+  assert len(grid_spec.block_mappings) == len(jaxpr.invars)
+  return jax_core.Primitive.bind(
+      pallas_call_p, *args,
+      jaxpr=jaxpr, name=name, in_shapes=in_shapes,
+      out_shapes=out_shapes, which_linear=which_linear,
+      interpret=interpret, debug=debug,
+      input_output_aliases=input_output_aliases,
+      grid_spec=grid_spec, **compiler_params)
+pallas_call_p.def_custom_bind(_pallas_call_bind)
+
 def pallas_call(f: Callable, out_shape: Any, *, debug: bool = False,
                 grid: Optional[Grid] = None,
                 in_specs: Optional[Sequence[Optional[BlockSpec]]] = None,
diff --git a/tests/pallas_test.py b/tests/pallas_test.py
index 49738f8f..49f95840 100644
--- a/tests/pallas_test.py
+++ b/tests/pallas_test.py
@@ -702,8 +702,7 @@ class PallasCallAutodifferentiationTest(PallasTest):
       ("square", lambda x: x * x),
       ("add_one", lambda x: x + 1.),
       ("exp", jnp.exp),
-      # ("tanh", jnp.tanh),  TODO(sharadmv): re-enable this case when libdevice is
-      # updated
+      ("tanh", jnp.tanh),
   ])
   def test_jvp(self, impl):
     @functools.partial(
@@ -728,8 +727,7 @@ def pallas_impl(x_ref, o_ref):
       ("square", lambda x: x * x),
       ("add_one", lambda x: x + 1.),
       ("exp", jnp.exp),
-      # ("tanh", jnp.tanh),  TODO(sharadmv): re-enable this case when libdevice is
-      # updated
+      ("tanh", jnp.tanh),
   ])
   def test_jvp_slice(self, impl):
     @functools.partial(
@@ -752,7 +750,6 @@ def pallas_impl(x_ref, o_ref):
                                rtol=1e-5)
     jtu.check_grads(pallas_impl, (x,), modes=["fwd"], order=2)
 
-  # TODO(sharadmv): enable this when we update Triton
   def test_jvp_matmul(self):
     k1, k2 = random.split(random.PRNGKey(0))
     x = random.normal(k1, (256, 128))
     y = random.normal(k2, (128, 64))
     bm, bn, bk, gm = 64, 128, 32, 8
     mm = functools.partial(matmul, bm=bm, bn=bn, bk=bk, gm=gm,
                            interpret=self.INTERPRET)
     jtu.check_grads(mm, (x, y), modes=["fwd"], order=1)
@@ -778,6 +775,24 @@ def add_vectors(x_ref, y_ref, o_ref):
     out_ref = xy[0] + xy[1]
     np.testing.assert_allclose(out, out_ref)
 
+  @parameterized.named_parameters(*[
+      ("square", lambda x: x * x),
+      ("add_one", lambda x: x + 1.),
+      ("exp", jnp.exp),
+      ("tanh", jnp.tanh),
+  ])
+  def test_grad(self, impl):
+    @functools.partial(
+        self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32))
+    def pallas_impl(x_ref, o_ref):
+      o_ref[...] = impl(x_ref[...])
+
+    x = random.normal(random.PRNGKey(0))
+    g = jax.grad(pallas_impl)(x)
+    g_ref = jax.grad(impl)(x)
+    np.testing.assert_allclose(g, g_ref, atol=1e-5, rtol=1e-5)
+    jtu.check_grads(pallas_impl, (x,), modes=["rev"], order=1)
+
 class PallasCallVmapTest(PallasTest):
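
A usage sketch of what this series enables (illustrative only: it assumes the
`jax_triton.pallas` module exports `pallas_call` the way the tests'
`self.pallas_call` wrapper uses it, and it mirrors the new `test_grad` above):

    import functools
    import jax
    import jax.numpy as jnp
    from jax_triton import pallas as pl

    # A scalar kernel, o = sin(x), written through refs as in the tests.
    @functools.partial(
        pl.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32))
    def sin_kernel(x_ref, o_ref):
      o_ref[...] = jnp.sin(x_ref[...])

    x = jnp.float32(1.0)
    # Forward pass: the partial-eval rule splits the call into a "{name}_known"
    # pallas_call that also writes residuals and a "{name}_unknown" call that
    # reads them back; reverse pass: the transpose rule binds "{name}_transpose".
    print(jax.grad(sin_kernel)(x))  # should agree with jnp.cos(1.0)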