
[Codegen][Tuner] Add support for per-sku tuning spec #19762

Merged on Jan 24, 2025 (9 commits)
@@ -4,13 +4,23 @@
// RUN: --iree-codegen-notify-transform-strategy-application \
// RUN: --verify-diagnostics %s | FileCheck %s

// RUN: iree-opt --split-input-file --iree-gpu-test-target=mi300x@hip \
// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-hal-configure-target-executable-variants{target=rocm})))" \
// RUN: --iree-codegen-enable-default-tuning-specs \
// RUN: --iree-codegen-notify-transform-strategy-application \
Contributor:
I missed when this flag was added, but it should be named --iree-codegen-test and/or be hidden or something. Not for this PR though.

// RUN: --verify-diagnostics %s | FileCheck %s --check-prefix=MI300X

// Check that the default configuration for mmt_2048x1280x5120_f16_f16_f32
// applies to the `linalg.matmul_transpose_b` below.

// CHECK-LABEL: func.func @mmt_2048x1280x5120_f16_f16_f32
// CHECK: linalg.generic
// CHECK-SAME: __tuning_spec_applied__

// MI300X-LABEL: func.func @mmt_2048x1280x5120_f16_f16_f32
// MI300X: linalg.generic
// MI300X-SAME: __tuning_spec_applied__

#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
@@ -21,9 +21,9 @@
// GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
// GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
// GFX942-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647],
// MI300X: chip = <wgp_count = 304>>
// MI300A: chip = <wgp_count = 228>>
// MI308X: chip = <wgp_count = 80>>
// MI300X: chip = <wgp_count = 304, sku = "mi300x">>
// MI300A: chip = <wgp_count = 228, sku = "mi300a">>
// MI308X: chip = <wgp_count = 80, sku = "mi308x">>

// GFX941: target = #iree_gpu.target<arch = "gfx941",
// GFX941-SAME: features = "+sramecc,-xnack"
@@ -111,6 +111,18 @@ getUserTuningSpec(ModuleOp module, IREE::Codegen::IREECodegenDialect &dialect) {
return *maybeTransformLibrary;
}

static std::optional<StringRef> fetchDefaultTuningSpec(StringRef identifier) {
std::string tuningSpecName =
llvm::formatv("iree_default_tuning_spec_{}.mlir", identifier);
std::optional<StringRef> tuningSpecSource;

EmbeddedDataDirectory::withGlobal([&](EmbeddedDataDirectory &dir) {
tuningSpecSource = dir.getFile(tuningSpecName);
});

return tuningSpecSource;
}

static FailureOr<ModuleOp>
getDefaultTuningSpec(ModuleOp module,
IREE::Codegen::IREECodegenDialect &dialect) {
@@ -123,14 +135,29 @@ getDefaultTuningSpec(ModuleOp module,
return failure();
}

// Try to look up the default tuning spec for this architecture, if any.
StringRef arch = gpuTarget.getArch();
std::string defaultTuningSpecName =
llvm::formatv("iree_default_tuning_spec_{}.mlir", arch);
std::optional<StringRef> sku;
if (IREE::GPU::TargetChipAttr chip = gpuTarget.getChip()) {
if (StringAttr chipSku = chip.getSku()) {
sku = chipSku.getValue();
}
}

std::string defaultTuningSpecName;
std::optional<StringRef> defaultTuningSpecSource;
EmbeddedDataDirectory::withGlobal([&](EmbeddedDataDirectory &dir) {
defaultTuningSpecSource = dir.getFile(defaultTuningSpecName);
});
if (sku) {
// GPUs with the same ISA may have different hardware characteristics, such
// as the number of workgroup processors and power limits. Look up the
// SKU-specific tuning spec for optimal performance.
defaultTuningSpecSource = fetchDefaultTuningSpec(*sku);
}

if (!defaultTuningSpecSource) {
// If a SKU-specific spec is not found, fall back to the default
// architecture-based tuning spec to ensure broader compatibility.
StringRef arch = gpuTarget.getArch();
defaultTuningSpecSource = fetchDefaultTuningSpec(arch);
}

if (!defaultTuningSpecSource) {
// Not all architectures are expected to provide default tuning specs, so
// this shouldn't be considered a hard error (but that's up to the caller).
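To illustrate the new selection order, here is a minimal sketch (not part of this PR) of the effective lookup for an MI300X target, whose chip attribute carries sku = "mi300x" and whose architecture is "gfx942". It reuses fetchDefaultTuningSpec and the iree_default_tuning_spec_{}.mlir naming convention from the code above; the helper name itself is made up for illustration.

// Sketch only: the effective lookup order for an MI300X target.
// fetchDefaultTuningSpec is the static helper added above; the function name
// below is hypothetical.
static std::optional<StringRef> lookupSpecForMi300x() {
  // 1) Prefer the SKU-specific spec: iree_default_tuning_spec_mi300x.mlir.
  if (std::optional<StringRef> spec = fetchDefaultTuningSpec("mi300x"))
    return spec;
  // 2) Fall back to the architecture spec: iree_default_tuning_spec_gfx942.mlir.
  //    If this also misses, the caller treats it as "no default spec".
  return fetchDefaultTuningSpec("gfx942");
}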
@@ -416,6 +416,8 @@ def IREEGPU_TargetChipAttr : AttrDef<IREEGPU_Dialect, "TargetChip"> {
let parameters = (ins
"uint32_t":$wgp_count,

// An optional SKU identifier to distinguish different models.
OptionalParameter<"StringAttr">:$sku,
// An optional extra dict.
// This field allows injecting more features/limits not supported in the
// above list, for better flexibility.
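Because the parameter is optional, the generated getSku() accessor returns a null StringAttr when no SKU was specified, which is how the pass change above reads it back. A minimal sketch of that query follows; the helper name is hypothetical.

// Sketch: reading the optional sku off a chip attribute. A null StringAttr
// means the parameter was omitted, so callers must check before using it.
static std::optional<StringRef> getSkuIfPresent(IREE::GPU::TargetChipAttr chip) {
  if (StringAttr sku = chip.getSku())
    return sku.getValue();
  return std::nullopt;
}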
@@ -54,6 +54,7 @@ struct WgpDetails {
// Chip level feature/limit details
struct ChipDetails {
uint32_t wgpCount;
std::optional<StringRef> sku;
};

// Full target details
@@ -116,9 +117,13 @@ TargetAttr createTargetAttr(const TargetDetails &details, StringRef arch,
DictionaryAttr{});

TargetChipAttr targetChip;
if (details.chip)
targetChip =
TargetChipAttr::get(context, details.chip->wgpCount, DictionaryAttr{});
if (details.chip) {
auto skuAttr = details.chip->sku
? StringAttr::get(context, *details.chip->sku)
: StringAttr{};
targetChip = TargetChipAttr::get(context, details.chip->wgpCount, skuAttr,
DictionaryAttr{});
}

return TargetAttr::get(context, arch, features, targetWgp, targetChip);
}
@@ -279,28 +284,27 @@ std::optional<TargetDetails> getAMDGPUTargetDetails(StringRef target) {

// "AMD Instinct MI300 Series Product Offerings" in Page 23 of
// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/white-papers/amd-cdna-3-white-paper.pdf
static const ChipDetails mi300xChip = {304};
static const ChipDetails mi300aChip = {228};
static const ChipDetails mi308xChip = {80};
static const ChipDetails mi300xChip = {304, "mi300x"};
static const ChipDetails mi300aChip = {228, "mi300a"};
static const ChipDetails mi308xChip = {80, "mi308x"};

// "AMD Instinct MI200 Series Accelerator Product Offerings" in Page 14 of
// https://www.amd.com/content/dam/amd/en/documents/instinct-business-docs/white-papers/amd-cdna2-white-paper.pdf
static const ChipDetails mi250xChip = {220};
static const ChipDetails mi250Chip = {208};
static const ChipDetails mi210Chip = {104};
static const ChipDetails mi250xChip = {220, "mi250x"};
static const ChipDetails mi250Chip = {208, "mi250"};
static const ChipDetails mi210Chip = {104, "mi210"};

// "AMD CDNA Architecture Compute Units" in Page 5 of
// https://www.amd.com/content/dam/amd/en/documents/instinct-business-docs/white-papers/amd-cdna-white-paper.pdf
static const ChipDetails mi100Chip = {120};
static const ChipDetails mi100Chip = {120, "mi100"};

static const ChipDetails rx7900xtxChip = {96};
static const ChipDetails rx7900xtChip = {84};
static const ChipDetails rx7800xtChip = {60};
static const ChipDetails rx7700xtChip = {54};
static const ChipDetails rx7900xtxChip = {96, "rx7900xtx"};
static const ChipDetails rx7900xtChip = {84, "rx7900xt"};
static const ChipDetails rx7800xtChip = {60, "rx7800xt"};
static const ChipDetails rx7700xtChip = {54, "rx7700xt"};

// See https://llvm.org/docs/AMDGPUUsage.html#processors for gfxN to
// cdnaN/rdnaN mapping.

return llvm::StringSwitch<std::optional<TargetDetails>>(target.lower())
.Case("mi300x", TargetDetails{cdna3Wgp, &mi300xChip})
.Case("mi300a", TargetDetails{cdna3Wgp, &mi300aChip})
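For illustration, wiring up an additional SKU would follow the same pattern as the entries above. The sketch below uses a made-up mi999x entry (not part of this PR) to show the two pieces that need to agree: a ChipDetails record carrying the sku string and a matching StringSwitch case. A corresponding embedded iree_default_tuning_spec_mi999x.mlir would also be needed for the per-SKU lookup to find anything.

// Hypothetical sketch (not in this PR): adding a new SKU entry. The wgp count
// and the choice of wgp details depend on the actual hardware.
static const ChipDetails mi999xChip = {/*wgpCount=*/123, "mi999x"};

// ... and inside the StringSwitch chain of getAMDGPUTargetDetails:
//   .Case("mi999x", TargetDetails{cdna3Wgp, &mi999xChip})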