Skip to content

Commit

Permalink
[TOSA] Switch to tablegen pass generation (#18227)
Browse files Browse the repository at this point in the history
This switches the pass generation definition to tablegen. The cleanup
includes switching passes to follow the `create*Pass` naming convention
and introduces anonymous namespaces.
  • Loading branch information
marbre authored Aug 16, 2024
1 parent 878a99b commit 551cd54
Show file tree
Hide file tree
Showing 10 changed files with 42 additions and 85 deletions.
1 change: 0 additions & 1 deletion compiler/plugins/input/TOSA/InputConversion/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ iree_gentbl_cc_library(
iree_compiler_cc_library(
name = "PassHeaders",
hdrs = [
"PassDetail.h",
"Passes.h",
"Passes.h.inc",
],
Expand Down
1 change: 0 additions & 1 deletion compiler/plugins/input/TOSA/InputConversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ iree_cc_library(
NAME
PassHeaders
HDRS
"PassDetail.h"
"Passes.h"
"Passes.h.inc"
DEPS
Expand Down
16 changes: 8 additions & 8 deletions compiler/plugins/input/TOSA/InputConversion/Converti48Toi64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "compiler/plugins/input/TOSA/InputConversion/PassDetail.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
Expand All @@ -16,7 +14,13 @@ using namespace mlir;

namespace mlir::iree_compiler {

class Converti48Toi64Pass : public Converti48Toi64Base<Converti48Toi64Pass> {
#define GEN_PASS_DEF_CONVERTI48TOI64PASS
#include "compiler/plugins/input/TOSA/InputConversion/Passes.h.inc"

namespace {

class Converti48Toi64Pass final
: public impl::Converti48Toi64PassBase<Converti48Toi64Pass> {
public:
explicit Converti48Toi64Pass() = default;
void runOnOperation() override;
Expand Down Expand Up @@ -174,9 +178,5 @@ void Converti48Toi64Pass::runOnOperation() {
}
}

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createConverti48Toi64() {
return std::make_unique<Converti48Toi64Pass>();
}

} // namespace
} // namespace mlir::iree_compiler
25 changes: 0 additions & 25 deletions compiler/plugins/input/TOSA/InputConversion/PassDetail.h

This file was deleted.

6 changes: 3 additions & 3 deletions compiler/plugins/input/TOSA/InputConversion/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,15 @@ void buildTOSAInputConversionPassPipeline(OpPassManager &passManager) {
passManager.addNestedPass<func::FuncOp>(tosa::createTosaToArith());
passManager.addNestedPass<func::FuncOp>(tosa::createTosaToTensor());
passManager.addNestedPass<func::FuncOp>(
iree_compiler::createTosaToLinalgExt());
iree_compiler::createTosaToLinalgExtPass());
passManager.addNestedPass<func::FuncOp>(mlir::createCanonicalizerPass());

TosaToLinalgNamedOptions tosaToLinalgNamedOptions;
tosaToLinalgNamedOptions.preferConv2DKernelLayoutHWCF = true;
tosa::addTosaToLinalgPasses(passManager, TosaToLinalgOptions(),
tosaToLinalgNamedOptions);
passManager.addNestedPass<func::FuncOp>(
iree_compiler::createConverti48Toi64());
iree_compiler::createConverti48Toi64Pass());

// Sometimes we generate more TOSA operations during the lowering to linalg.
passManager.addNestedPass<func::FuncOp>(tosa::createTosaToArith());
Expand All @@ -74,7 +74,7 @@ void buildTOSAInputConversionPassPipeline(OpPassManager &passManager) {
//----------------------------------------------------------------------------
// Entry dialect cleanup
//----------------------------------------------------------------------------
passManager.addPass(createVerifyCompilerTOSAInputLegality());
passManager.addPass(createVerifyCompilerTOSAInputLegalityPass());
}

namespace {
Expand Down
20 changes: 3 additions & 17 deletions compiler/plugins/input/TOSA/InputConversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,27 +29,13 @@ void registerTOSAConversionPassPipeline();
// Set of patterns for materializing TOSA operations to linalg_ext.
void populateTosaToLinalgExtPatterns(RewritePatternSet *patterns);

// Converts i48 to i64.
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createConverti48Toi64();

// Strips the signed/unsigned portion off of tensors.
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createStripSignednessPass();

// Converts TOSA operations to linalg_ext.
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createTosaToLinalgExt();

// Verifies that a module only contains IR structures that are supported by the
// core compiler.
std::unique_ptr<OperationPass<ModuleOp>>
createVerifyCompilerTOSAInputLegality();

//===----------------------------------------------------------------------===//
// Register all Passes
//===----------------------------------------------------------------------===//

#define GEN_PASS_DECL
#include "compiler/plugins/input/TOSA/InputConversion/Passes.h.inc" // IWYU pragma: export

void registerTOSAConversionPasses();

} // namespace mlir::iree_compiler
Expand Down
12 changes: 4 additions & 8 deletions compiler/plugins/input/TOSA/InputConversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,19 @@

include "mlir/Pass/PassBase.td"

def Converti48Toi64 :
def Converti48Toi64Pass :
InterfacePass<"iree-tosa-convert-i48-to-i64", "mlir::FunctionOpInterface"> {
let summary = "Converts all i48s to i64s";
let constructor = "mlir::iree_compiler::createConverti48Toi64()";
}

def StripSignedness :
def StripSignednessPass :
InterfacePass<"iree-tosa-strip-signedness", "mlir::FunctionOpInterface"> {
let summary = "Legalizes ui tensors constants to uis";
let constructor = "mlir::iree_compiler::createStripSignednessPass()";
}

def TosaToLinalgExt :
def TosaToLinalgExtPass :
InterfacePass<"iree-tosa-to-linalg-ext", "mlir::FunctionOpInterface"> {
let summary = "Convert TOSA operations to their equivalent linalg-ext operations.";
let constructor = "mlir::iree_compiler::createTosaToLinalgExt()";
let dependentDialects = [
"arith::ArithDialect",
"linalg::LinalgDialect",
Expand All @@ -33,10 +30,9 @@ def TosaToLinalgExt :
];
}

def VerifyCompilerTOSAInputLegality :
def VerifyCompilerTOSAInputLegalityPass :
Pass<"iree-tosa-verify-compiler-input-legality", "ModuleOp"> {
let summary = "Verifies that only supported IR constructs are passed to the compiler.";
let constructor = "mlir::iree_compiler::createVerifyCompilerTOSAInputLegality()";
}

#endif // IREE_COMPILER_PLUGINS_INPUT_TOSA_INPUTCONVERSION_PASSES
12 changes: 5 additions & 7 deletions compiler/plugins/input/TOSA/InputConversion/StripSignedness.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,20 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "compiler/plugins/input/TOSA/InputConversion/PassDetail.h"
#include "compiler/plugins/input/TOSA/InputConversion/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_STRIPSIGNEDNESSPASS
#include "compiler/plugins/input/TOSA/InputConversion/Passes.h.inc"

namespace {

class StripSignednessPass : public StripSignednessBase<StripSignednessPass> {
class StripSignednessPass final
: public impl::StripSignednessPassBase<StripSignednessPass> {
public:
explicit StripSignednessPass() {}
void runOnOperation() override;
Expand Down Expand Up @@ -125,9 +128,4 @@ void StripSignednessPass::runOnOperation() {

} // namespace

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createStripSignednessPass() {
return std::make_unique<StripSignednessPass>();
}

} // namespace mlir::iree_compiler
17 changes: 10 additions & 7 deletions compiler/plugins/input/TOSA/InputConversion/TosaToLinalgExt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "compiler/plugins/input/TOSA/InputConversion/PassDetail.h"
#include "compiler/plugins/input/TOSA/InputConversion/Passes.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
Expand All @@ -21,6 +20,11 @@ using namespace mlir::tosa;

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_TOSATOLINALGEXTPASS
#include "compiler/plugins/input/TOSA/InputConversion/Passes.h.inc"

namespace {

// Converts tosa.scatter to the iree_linalg_ext.scatter operation. As the
// LinalgExt version is not batched therefore we materialize the batch index
// for each update.
Expand Down Expand Up @@ -145,7 +149,9 @@ class ScatterConversion : public OpRewritePattern<tosa::ScatterOp> {
}
};

struct TosaToLinalgExtPass : public TosaToLinalgExtBase<TosaToLinalgExtPass> {
class TosaToLinalgExtPass final
: public impl::TosaToLinalgExtPassBase<TosaToLinalgExtPass> {
public:
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
ConversionTarget target(getContext());
Expand All @@ -159,13 +165,10 @@ struct TosaToLinalgExtPass : public TosaToLinalgExtBase<TosaToLinalgExtPass> {
}
};

} // namespace

void populateTosaToLinalgExtPatterns(RewritePatternSet *patterns) {
patterns->add<ScatterConversion>(patterns->getContext());
}

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createTosaToLinalgExt() {
return std::make_unique<TosaToLinalgExtPass>();
}

} // namespace mlir::iree_compiler
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "compiler/plugins/input/TOSA/InputConversion/PassDetail.h"
#include "compiler/plugins/input/TOSA/InputConversion/Passes.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Pass/Pass.h"
Expand All @@ -13,9 +12,15 @@

namespace mlir::iree_compiler {

struct VerifyCompilerTOSAInputLegalityPass
: public VerifyCompilerTOSAInputLegalityBase<
#define GEN_PASS_DEF_VERIFYCOMPILERTOSAINPUTLEGALITYPASS
#include "compiler/plugins/input/TOSA/InputConversion/Passes.h.inc"

namespace {

class VerifyCompilerTOSAInputLegalityPass final
: public impl::VerifyCompilerTOSAInputLegalityPassBase<
VerifyCompilerTOSAInputLegalityPass> {
public:
void runOnOperation() override {
auto *context = &getContext();
ConversionTarget conversionTarget(*context);
Expand Down Expand Up @@ -63,9 +68,5 @@ struct VerifyCompilerTOSAInputLegalityPass
}
};

std::unique_ptr<OperationPass<ModuleOp>>
createVerifyCompilerTOSAInputLegality() {
return std::make_unique<VerifyCompilerTOSAInputLegalityPass>();
}

} // namespace
} // namespace mlir::iree_compiler

0 comments on commit 551cd54

Please sign in to comment.