[Codegen][Tuner] populate the default tuning specs for mi308x
Signed-off-by: Bangtian Liu <[email protected]>
bangtianliu committed Jan 24, 2025
1 parent 8c5ff94 commit 216195f
Showing 6 changed files with 123 additions and 0 deletions.
1 change: 1 addition & 0 deletions compiler/plugins/target/ROCM/builtins/tuning/BUILD.bazel
@@ -26,6 +26,7 @@ endif()
# Target archs for tuning specs. https://llvm.org/docs/AMDGPUUsage.html#processors
gpu_archs = [
"gfx942",
"mi308x",
]

tuning_spec_mlir_files = [
2 changes: 2 additions & 0 deletions compiler/plugins/target/ROCM/builtins/tuning/CMakeLists.txt
@@ -19,6 +19,7 @@ iree_c_embed_data(
iree_default_tuning_specs_amdgpu
SRCS
"iree_default_tuning_spec_gfx942.mlir"
"iree_default_tuning_spec_mi308x.mlir"
C_FILE_OUTPUT
"iree_default_tuning_specs_amdgpu.c"
H_FILE_OUTPUT
@@ -32,6 +33,7 @@
verify_default_tuning_specs_amdgpu
SRCS
"iree_default_tuning_spec_gfx942.mlir"
"iree_default_tuning_spec_mi308x.mlir"
TOOLS
iree-opt
)
72 changes: 72 additions & 0 deletions compiler/plugins/target/ROCM/builtins/tuning/iree_default_tuning_spec_mi308x.mlir
@@ -0,0 +1,72 @@
// RUN: iree-opt %s

// This is just an initial tuning spec for mi308x and is not intended for
// production use.
// TODO(https://github.com/iree-org/iree/issues/19214): Add missing
// configurations to this spec.

module @iree_default_tuning_spec_mi308x attributes { transform.with_named_sequence, iree_codegen.tuning_spec_with_default_entrypoint } {

transform.named_sequence @apply_op_config(%op: !transform.any_op {transform.readonly},
%config: !transform.any_param {transform.readonly}) {
// transform.print %op {name="Apply on"} : !transform.any_op
transform.annotate %op "compilation_info" = %config : !transform.any_op, !transform.any_param
// Add a dummy unit attribute to make sure that the tuning spec was applied.
// Otherwise it would be difficult to tell if the lowering config attribute
// comes from our tuning spec or if the compiler heuristic happened to produce
// the same config as this script.
transform.annotate %op "__tuning_spec_applied__" : !transform.any_op
transform.yield
}

transform.named_sequence @match_mmt_f16_f16_f32(%root: !transform.any_op {transform.readonly}) -> !transform.any_op {
transform.match.operation_name %root ["linalg.generic"] : !transform.any_op
// transform.print %root {name = "Generic"} : !transform.any_op
%ins, %outs = transform.iree.match.cast_compatible_dag_from_root %root {
^bb0(%lhs: tensor<?x?xf16>, %rhs: tensor<?x?xf16>, %out: tensor<?x?xf32>):
%7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
affine_map<(d0, d1, d2) -> (d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"]}
ins(%lhs, %rhs : tensor<?x?xf16>, tensor<?x?xf16>) outs(%out : tensor<?x?xf32>) {
^bb0(%in: f16, %in_0: f16, %acc: f32):
%8 = arith.extf %in : f16 to f32
%9 = arith.extf %in_0 : f16 to f32
%10 = arith.mulf %8, %9 : f32
%11 = arith.addf %acc, %10 : f32
linalg.yield %11 : f32
} -> tensor<?x?xf32>
} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
transform.yield %root : !transform.any_op
}

transform.named_sequence @match_mmt_1920x1280x1280(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
%mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
%lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value
%rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value
transform.iree.match.cast_compatible_type %lhs = tensor<1920x1280xf16> : !transform.any_value
transform.iree.match.cast_compatible_type %rhs = tensor<1280x1280xf16> : !transform.any_value
%config = transform.param.constant #iree_codegen.compilation_info<
lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1],
mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
subgroup_m_count = 4, subgroup_n_count = 2,
reduction = [0, 0, 32],
workgroup = [128, 128, 0]}>,
translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
workgroup_size = [128, 4, 1] subgroup_size = 64,
{gpu_pipeline_options = #iree_gpu.pipeline_options<prefetch_shared_memory = true>,
llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}
}>> -> !transform.any_param
transform.yield %matmul, %config : !transform.any_op, !transform.any_param
}

transform.named_sequence
@__kernel_config(%variant_op: !transform.any_op {transform.consumed}) -> !transform.any_op
attributes { iree_codegen.tuning_spec_entrypoint } {
%res = transform.foreach_match in %variant_op
@match_mmt_1920x1280x1280 -> @apply_op_config
: (!transform.any_op) -> !transform.any_op
transform.yield %res : !transform.any_op
}

}
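
As a rough cross-check, the tile and subgroup parameters in the compilation_info above work out as follows:

  workgroup_size [128, 4, 1]   -> 128 * 4 * 1 = 512 threads per workgroup
  subgroups per workgroup      -> 512 / subgroup_size (64) = 8 = subgroup_m_count (4) * subgroup_n_count (2)
  workgroup tile [128, 128, 0] -> each subgroup owns a (128 / 4) x (128 / 2) = 32 x 64 slice of the 128 x 128 output tile
  reduction tile [0, 0, 32]    -> 32 elements of K per step, i.e. two K-iterations of the MFMA_F32_16x16x16_F16 intrinsic
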
1 change: 1 addition & 0 deletions compiler/plugins/target/ROCM/builtins/tuning/test/BUILD.bazel
@@ -26,6 +26,7 @@ iree_lit_test_suite(
name = "lit",
srcs = [
"spec_gfx942.mlir",
"spec_mi308x.mlir",
],
cfg = "//compiler:lit.cfg.py",
tools = [
1 change: 1 addition & 0 deletions compiler/plugins/target/ROCM/builtins/tuning/test/CMakeLists.txt
@@ -19,6 +19,7 @@ iree_lit_test_suite(
lit
SRCS
"spec_gfx942.mlir"
"spec_mi308x.mlir"
TOOLS
FileCheck
iree-opt
46 changes: 46 additions & 0 deletions compiler/plugins/target/ROCM/builtins/tuning/test/spec_mi308x.mlir
@@ -0,0 +1,46 @@
// RUN: iree-opt --split-input-file --iree-gpu-test-target=mi308x@hip \
// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-hal-configure-target-executable-variants{target=rocm})))" \
// RUN: --iree-codegen-enable-default-tuning-specs \
// RUN: --iree-codegen-notify-transform-strategy-application \
// RUN: --verify-diagnostics %s | FileCheck %s

// Check that the default configuration for mmt_1920x1280x1280_f16_f16_f32
// applies to the `linalg.matmul_transpose_b` below.

// CHECK-LABEL: func.func @mmt_1920x1280x1280_f16_f16_f32
// CHECK: linalg.generic
// CHECK-SAME: __tuning_spec_applied__

#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>
]>
hal.executable public @main {
hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
hal.executable.export public @matmul_transpose_b ordinal(0) layout(#pipeline_layout) {
^bb0(%arg0: !hal.device):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
hal.return %x, %y, %z : index, index, index
}
builtin.module {
// expected-remark@+1 {{Applied transform configuration strategy @iree_default_tuning_spec_mi308x::@__kernel_config}}
func.func @mmt_1920x1280x1280_f16_f16_f32() {
%cst = arith.constant 0.000000e+00 : f16
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1920x1280xf16>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1280x1280xf16>>
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<1920x1280xf32>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1920, 1280], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1920x1280xf16>> -> tensor<1920x1280xf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [1280, 1280], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1280x1280xf16>> -> tensor<1280x1280xf16>
%5 = tensor.empty() : tensor<1920x1280xf32>
%6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<1920x1280xf32>) -> tensor<1920x1280xf32>
%7 = linalg.matmul_transpose_b
ins(%3, %4 : tensor<1920x1280xf16>, tensor<1280x1280xf16>)
outs(%6 : tensor<1920x1280xf32>) -> tensor<1920x1280xf32>
flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [1920, 1280], strides = [1, 1] : tensor<1920x1280xf32> -> !flow.dispatch.tensor<writeonly:tensor<1920x1280xf32>>
return
}
}
}
}
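
For reference, @apply_op_config attaches the compilation_info and the __tuning_spec_applied__ marker as discardable attributes on the matched op, which is what the CHECK lines above key on. A minimal sketch of roughly how the configured linalg.generic prints afterwards (attribute payloads abbreviated; the exact rendering depends on the IREE build):

%7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                                      affine_map<(d0, d1, d2) -> (d1, d2)>,
                                      affine_map<(d0, d1, d2) -> (d0, d1)>],
                     iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%3, %4 : tensor<1920x1280xf16>, tensor<1280x1280xf16>)
      outs(%6 : tensor<1920x1280xf32>)
      // Discardable attributes added by transform.annotate in @apply_op_config:
      attrs = {__tuning_spec_applied__,
               compilation_info = #iree_codegen.compilation_info<lowering_config = ..., translation_info = ...>} {
^bb0(%in: f16, %in_0: f16, %acc: f32):
  ...
} -> tensor<1920x1280xf32>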
