Add CPU backend #322

Draft · wants to merge 4 commits into main
14 changes: 0 additions & 14 deletions jax_triton/__init__.py
@@ -24,7 +24,6 @@
"__version_info__",
]

from jax._src.lib import gpu_triton
from jax_triton import utils
from jax_triton.triton_lib import triton_call
from jax.experimental.pallas import cdiv
@@ -33,17 +32,4 @@
from jax_triton.version import __version__
from jax_triton.version import __version_info__

try:
get_compute_capability = gpu_triton.get_compute_capability
get_serialized_metadata = gpu_triton.get_serialized_metadata
except AttributeError:
raise ImportError(
"jax-triton requires JAX to be installed with GPU support. The "
"installation page on the JAX documentation website includes "
"instructions for installing a supported version:\n"
"https://jax.readthedocs.io/en/latest/installation.html"
)
else:
del gpu_triton # Not part of the API.

# trailer
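
With the unconditional `gpu_triton` import and the GPU-only `ImportError` removed, `import jax_triton` should no longer fail on CPU-only installs, but `jt.get_compute_capability` may be undefined there. A minimal caller-side guard, sketched under that assumption (the `None` fallback is illustrative, not part of this PR):

```python
import jax_triton as jt

# Probe for the GPU-only helper; on CPU-only builds it may be absent.
try:
  compute_capability = jt.get_compute_capability(0)
except AttributeError:
  compute_capability = None  # Illustrative fallback for CPU-only installs.
```
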
194 changes: 172 additions & 22 deletions jax_triton/triton_lib.py
@@ -62,7 +62,12 @@
import triton.backends.amd.compiler as hb
except ImportError:
hb = None
pass

try:
import triton.backends.cpu.compiler as cpub

except ImportError:
cpub = None


os.environ["TRITON_CACHE_DIR"] = ""
@@ -170,6 +175,13 @@ def get_hip_backend(device, compute_capability):
backend = hb.HIPBackend(target)
return backend

def get_cpu_backend(device, compute_capability):
arch = _triton.llvm.get_cpu_tripple()
arch = arch.split("-")[0]
target = cpub.GPUTarget('cpu', arch, 0)
backend = cpub.CPUBackend(target)
return backend

@dataclasses.dataclass
class CompilationResult:
binary: str
@@ -181,8 +193,8 @@ class CompilationResult:

def compile_ttir_inplace(
ttir,
backend: [cb.CUDABackend | hb.HIPBackend],
options: [cb.CUDAOptions | hb.HIPOptions],
backend: cb.CUDABackend | hb.HIPBackend | cpub.CPUBackend,
options: cb.CUDAOptions | hb.HIPOptions | cpub.CPUOptions,
compute_capability,
platform
):
@@ -201,6 +213,13 @@ def compile_ttir_inplace(
options,
compute_capability,
)
elif platform == 'cpu':
return compile_ttir_to_asm_inplace(
ttir,
backend,
options,
compute_capability,
)
else:
raise ValueError(
"Unsupported device."
@@ -322,6 +341,70 @@ def compile_ttir_to_hsaco_inplace(
llir=llir,
)

def compile_ttir_to_asm_inplace(
ttir,
cpu_backend: cpub.CPUBackend,
cpu_options: cpub.CPUOptions,
compute_capability,
) -> CompilationResult:
if cpu_options.debug:
print(ttir)
try:
metadata = {}
opt_ttir = cpu_backend.make_ttir(ttir, metadata, cpu_options)
ttcir = cpu_backend.make_ttcir(
opt_ttir,
metadata,
cpu_options
)
except RuntimeError as e:
ttir.dump()
raise ValueError("TTIR->TTCIR pass failed!") from e
if cpu_options.debug:
print(ttcir)
try:
tttcir = cpu_backend.make_tttcir(
ttcir,
metadata,
cpu_options
)
except RuntimeError as e:
ttcir.dump()
raise ValueError("TTCIR->TTTCIR pass failed!") from e
if cpu_options.debug:
print(tttcir)
try:
llir = cpu_backend.make_llir(
tttcir,
metadata,
cpu_options
)
except RuntimeError as e:
tttcir.dump()
raise ValueError("TTTCIR->LLIR pass failed!") from e
shared_mem_bytes = metadata["shared"]
if cpu_options.debug:
print(llir)
asm = cpu_backend.make_asm(
llir,
metadata,
cpu_options
)
if cpu_options.debug:
print(asm)
name = metadata["name"]
cluster_dims = metadata["cluster_dims"]
tttcir = str(tttcir) if _JAX_TRITON_DUMP_DIR else None
llir = str(llir) if _JAX_TRITON_DUMP_DIR else None
return CompilationResult(
binary=asm,
name=name,
shared_mem_bytes=shared_mem_bytes,
cluster_dims=cluster_dims,
ttgir=tttcir,
llir=llir,
)

_COMPILED_KERNEL_CACHE = {} # TODO(cjfj): Convert to LRU cache?


@@ -690,6 +773,12 @@ def prune_configs(configs, named_args, **kwargs):
platform="rocm",
)

mlir.register_lowering(
triton_kernel_call_p,
functools.partial(triton_kernel_call_lowering, get_cpu_backend),
platform="cpu",
)

class ShapeDtype(Protocol):

@property
@@ -827,23 +916,84 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
if input_output_aliases is None:
input_output_aliases = {}

out_flat = triton_kernel_call_p.bind(
*array_args,
fn=kernel,
scalar_args=tuple(scalar_args),
name=name,
custom_call_target_name=custom_call_target_name,
out_shapes=tuple(flat_out_shapes),
grid=grid,
num_warps=num_warps,
num_stages=num_stages,
num_ctas=num_ctas,
compute_capability=compute_capability,
enable_fp_fusion=enable_fp_fusion,
input_output_aliases=tuple(input_output_aliases.items()),
zeroed_outputs=zeroed_outputs,
debug=debug,
serialized_metadata=serialized_metadata,
**metaparams,
)
if triton.runtime.driver.active.get_current_target().backend != "cpu":
out_flat = triton_kernel_call_p.bind(
*array_args,
fn=kernel,
scalar_args=tuple(scalar_args),
name=name,
custom_call_target_name=custom_call_target_name,
out_shapes=tuple(flat_out_shapes),
grid=grid,
num_warps=num_warps,
num_stages=num_stages,
num_ctas=num_ctas,
compute_capability=compute_capability,
enable_fp_fusion=enable_fp_fusion,
input_output_aliases=tuple(input_output_aliases.items()),
zeroed_outputs=zeroed_outputs,
debug=debug,
serialized_metadata=serialized_metadata,
**metaparams,
)
else:
if isinstance(kernel, autotuner.Autotuner):
for config in kernel.configs:
if config.pre_hook is not None:
raise NotImplementedError("`pre_hook` is not supported")

class Pointer:

def __init__(self, x):
self.x = x
self.dtype = x.dtype

def data_ptr(self):
return self.x.unsafe_buffer_pointer()

def to_triton_arg(arg):
if arg.ndim == 0:
dtypes = {
jnp.bool.dtype: bool,
jnp.int32.dtype: int,
jnp.int64.dtype: int,
jnp.float32.dtype: float,
jnp.float64.dtype: float,
}
if arg.dtype not in dtypes:
raise ValueError(f"Invalid argument {arg} with type {arg.dtype}.")
return dtypes[arg.dtype](arg)
else:
return Pointer(arg)

def callback(flat_args, outputs):
kernel[lambda meta: normalize_grid(grid, metaparams | meta)](
*map(to_triton_arg, flat_args),
*map(Pointer, outputs),
**metaparams,
)
return outputs

# FIXME(stephen-huan): doesn't take into account kernel's meta
config_zeroed_outputs = zeroed_outputs
if callable(zeroed_outputs):
config_zeroed_outputs = config_zeroed_outputs(metaparams)

output_input_aliases = {}
for input_idx, output_idx in input_output_aliases.items():
if output_idx in output_input_aliases:
# TODO(stephen-huan): not sure how to handle this properly
raise NotImplementedError(
"Multiple inputs aliased to the same output is not supported."
)
output_input_aliases[output_idx] = flat_args[input_idx]
if output_idx in config_zeroed_outputs:
flat_args[input_idx] = flat_args[input_idx].at[:].set(0)

out_shapes = tuple(flat_out_shapes)
outputs = [
output_input_aliases.get(i, jnp.zeros(shape.shape, shape.dtype))
for i, shape in enumerate(out_shapes)
]
out_flat = jax.pure_callback(callback, out_shapes, flat_args, outputs)
return tree_util.tree_unflatten(out_tree, out_flat)
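
For context, a minimal sketch of how the new CPU path might be exercised, assuming a triton-cpu install where `triton.runtime.driver.active.get_current_target().backend == "cpu"`; on such a target `triton_call` takes the `jax.pure_callback` branch added above rather than the GPU custom call. The kernel, shapes, and block size are illustrative, not taken from this PR:

```python
import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, block_size: tl.constexpr):
  # Each program instance handles one block_size-wide slice of the inputs.
  pid = tl.program_id(axis=0)
  offsets = pid * block_size + tl.arange(0, block_size)
  x = tl.load(x_ptr + offsets)
  y = tl.load(y_ptr + offsets)
  tl.store(out_ptr + offsets, x + y)


def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  block_size = 8
  return jt.triton_call(
      x, y,
      kernel=add_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      grid=(triton.cdiv(x.size, block_size),),
      block_size=block_size,
  )


x = jnp.arange(8, dtype=jnp.float32)
y = jnp.ones(8, dtype=jnp.float32)
print(add(x, y))  # Should print [1. 2. 3. ... 8.] on either backend.
```
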
2 changes: 1 addition & 1 deletion tests/cluster_test.py
@@ -74,7 +74,7 @@ def test_invalid_cluster_size(self):
_dummy_fn(jnp.empty((16,)))

def test_cluster_not_available(self):
if 'h100' in jax.devices()[0].device_kind.lower():
if 'h100' not in jax.devices()[0].device_kind.lower():
self.skipTest('Clusters available only on H100s.')

my_triton_call = functools.partial(jt.triton_call, num_ctas=2)
8 changes: 7 additions & 1 deletion tests/triton_call_test.py
@@ -29,6 +29,12 @@

config.parse_flags_with_absl()

try:
jt.get_compute_capability(0)
except AttributeError:
# TODO(stephen-huan): add in jaxlib
jt.get_compute_capability = lambda _: np.inf


def setUpModule():
config.update("jax_enable_x64", True)
@@ -553,7 +559,7 @@ def test_specialization(self):
BLOCK_SIZE_M=32,
BLOCK_SIZE_N=32,
BLOCK_SIZE_K=32,
# K_EXACTLY_DIVISIBLE_BY_BLOCK=False,
K_EXACTLY_DIVISIBLE_BY_BLOCK=False,
)
except TypeError:
pass # Error thrown as the mocked method's return value is invalid.
5 changes: 4 additions & 1 deletion tests/triton_test.py
@@ -20,7 +20,10 @@
import numpy as np
import triton
import triton.language as tl
from triton.language.extra.cuda import libdevice
try:
from triton.language.extra.cuda import libdevice
except ImportError:
from triton.language.extra.cpu import libdevice


@triton.jit
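
A hedged sketch of the kind of kernel this import fallback is meant to keep working, assuming the `cpu` `libdevice` module mirrors the `cuda` one for the entry point used here (`asin` is an illustrative choice, not taken from this diff):

```python
import triton
import triton.language as tl

try:
  from triton.language.extra.cuda import libdevice
except ImportError:
  from triton.language.extra.cpu import libdevice  # CPU-only Triton builds.


@triton.jit
def asin_kernel(x_ptr, out_ptr, block_size: tl.constexpr):
  # Identical kernel body on either backend; only the libdevice provider differs.
  offsets = tl.program_id(axis=0) * block_size + tl.arange(0, block_size)
  x = tl.load(x_ptr + offsets)
  tl.store(out_ptr + offsets, libdevice.asin(x))
```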