From e6d7428fda4997a8fd47dc3c419b34cadd1223c7 Mon Sep 17 00:00:00 2001 From: Muhammad Asif Manzoor Date: Wed, 17 Jul 2024 15:33:27 +0000 Subject: [PATCH] Add greater or equal op end to end --- include/ttmlir/Dialect/TTIR/IR/TTIROps.td | 7 +++++++ include/ttmlir/Dialect/TTNN/IR/TTNNOps.td | 7 +++++++ include/ttmlir/Target/TTNN/program.fbs | 1 + lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp | 2 ++ lib/Dialect/TTIR/Transforms/Passes.cpp | 8 ++++++++ lib/Dialect/TTNN/Transforms/Passes.cpp | 1 + .../TTNN/Transforms/TTNNToSerializedBinary.cpp | 5 +++++ runtime/lib/ttnn/program.cpp | 8 ++++++++ test/ttmlir/Dialect/TTNN/simple_ge.mlir | 15 +++++++++++++++ 9 files changed, 54 insertions(+) create mode 100644 test/ttmlir/Dialect/TTNN/simple_ge.mlir diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index a8efff692..c9f1b34a7 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -153,6 +153,13 @@ def TTIR_MultiplyOp : TTIR_ElementwiseBinaryOp<"multiply"> { }]; } +def TTIR_GreaterEqualOp : TTIR_ElementwiseBinaryOp<"ge"> { + let summary = "Eltwise greater than or equal to."; + let description = [{ + Eltwise greater than or equal to operation. + }]; +} + class TTIR_ReductionOp traits = []> : TTIR_DPSOp { let summary = "Reduction op."; let description = [{ diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td index 840f9b64f..3fc48f0a0 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td @@ -92,6 +92,13 @@ def TTNN_MultiplyOp : TTNN_ElementwiseBinaryOp<"multiply"> { }]; } +def TTNN_GreaterEqualOp : TTNN_ElementwiseBinaryOp<"ge"> { + let summary = "Eltwise greater than or equal to."; + let description = [{ + Eltwise greater than or equal to operation. + }]; +} + class TTNN_ReductionOp traits = []> : TTNN_NamedDPSOp { let summary = "Reduction op."; let description = [{ diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs index 581fafc14..a2251c1cb 100644 --- a/include/ttmlir/Target/TTNN/program.fbs +++ b/include/ttmlir/Target/TTNN/program.fbs @@ -29,6 +29,7 @@ enum EltwiseOpType: uint32 { Multiply = 1, Subtract = 2, Relu = 3, + GreaterEqual = 4, } table EltwiseOp { diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp index a140e8d1c..491ab7895 100644 --- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp +++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp @@ -52,6 +52,8 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx, patterns.add>(typeConverter, ctx); patterns.add>(typeConverter, ctx); + patterns.add>(typeConverter, + ctx); patterns.add>(typeConverter, ctx); patterns.add>(typeConverter, ctx); patterns.add>(typeConverter, diff --git a/lib/Dialect/TTIR/Transforms/Passes.cpp b/lib/Dialect/TTIR/Transforms/Passes.cpp index 7fce43ab2..662b018b6 100644 --- a/lib/Dialect/TTIR/Transforms/Passes.cpp +++ b/lib/Dialect/TTIR/Transforms/Passes.cpp @@ -75,6 +75,9 @@ class ConvertTosaToTTIR TosaToTTIREltwiseBinaryRewriter, TosaToTTIREltwiseBinaryRewriter, + TosaToTTIREltwiseBinaryRewriter>( &getContext()); FrozenRewritePatternSet patternSet(std::move(patterns)); @@ -115,6 +118,9 @@ class TTIRNamedToKernelRewriter : public OpRewritePattern { } else if constexpr (std::is_same::value) { kernelName = "subtract"; kernelKind = "eltwise"; + } else if constexpr (std::is_same::value) { + kernelName = "ge"; + kernelKind = "eltwise"; } else if constexpr (std::is_same::value) { kernelName = "relu"; kernelKind = "eltwise"; @@ -272,6 +278,7 @@ class TTIRGeneric : public impl::TTIRGenericBase { TTIRNamedToKernelRewriter, TTIRNamedToKernelRewriter, TTIRNamedToKernelRewriter, + TTIRNamedToKernelRewriter, TTIRNamedToKernelRewriter>(&getContext()); FrozenRewritePatternSet patternSet(std::move(patterns)); if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) { @@ -567,6 +574,7 @@ class TTIRLayout : public impl::TTIRLayoutBase { TTIRLayoutOperandsRewriter, TTIRLayoutOperandsRewriter, TTIRLayoutOperandsRewriter, + TTIRLayoutOperandsRewriter, TTIRLayoutOperandsRewriter, TTIRLayoutOperandsRewriter, TTIRLayoutOperandsRewriter, TTIRLayoutOperandsRewriter, TTIRLayoutFuncReturnRewriter>( diff --git a/lib/Dialect/TTNN/Transforms/Passes.cpp b/lib/Dialect/TTNN/Transforms/Passes.cpp index 1102c8106..358a025f1 100644 --- a/lib/Dialect/TTNN/Transforms/Passes.cpp +++ b/lib/Dialect/TTNN/Transforms/Passes.cpp @@ -159,6 +159,7 @@ class ConvertTTIRToTTNN .add, TTIRToTTNNOpRewriter, TTIRToTTNNOpRewriter, + TTIRToTTNNOpRewriter, TTIRToTTNNOpRewriter, TTIRToTTNNBinaryOpRewriter, TTIRToTTNNReductionOpRewriter, diff --git a/lib/Dialect/TTNN/Transforms/TTNNToSerializedBinary.cpp b/lib/Dialect/TTNN/Transforms/TTNNToSerializedBinary.cpp index bb96925d1..a1bee9b42 100644 --- a/lib/Dialect/TTNN/Transforms/TTNNToSerializedBinary.cpp +++ b/lib/Dialect/TTNN/Transforms/TTNNToSerializedBinary.cpp @@ -113,6 +113,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) { type = ::tt::target::ttnn::EltwiseOpType::Multiply; } else if constexpr (std::is_same_v) { type = ::tt::target::ttnn::EltwiseOpType::Subtract; + } else if constexpr (std::is_same_v) { + type = ::tt::target::ttnn::EltwiseOpType::GreaterEqual; } else if constexpr (std::is_same_v) { type = ::tt::target::ttnn::EltwiseOpType::Relu; } else { @@ -191,6 +193,9 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op, return createOperation(cache, createEltwiseOp(cache, subtractOp), debugString); } + if (auto geOp = dyn_cast(op); geOp) { + return createOperation(cache, createEltwiseOp(cache, geOp), debugString); + } if (auto reluOp = dyn_cast(op); reluOp) { return createOperation(cache, createEltwiseOp(cache, reluOp), debugString); } diff --git a/runtime/lib/ttnn/program.cpp b/runtime/lib/ttnn/program.cpp index 9d5572ab3..be0d13223 100644 --- a/runtime/lib/ttnn/program.cpp +++ b/runtime/lib/ttnn/program.cpp @@ -92,6 +92,14 @@ run(::tt::target::ttnn::EltwiseOp const *op, ::ttnn::Device &device, liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back()); break; } + case ::tt::target::ttnn::EltwiseOpType::GreaterEqual: { + assert(op->ins()->size() == 2 && "Unsupported number of inputs"); + ::ttnn::Tensor &lhs = *liveTensors.at(op->ins()->Get(0)->global_id()); + ::ttnn::Tensor &rhs = *liveTensors.at(op->ins()->Get(1)->global_id()); + tensorPool.push_back(::ttnn::ge(lhs, rhs)); + liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back()); + break; + } default: throw std::runtime_error("Unsupported elementwise operation type"); } diff --git a/test/ttmlir/Dialect/TTNN/simple_ge.mlir b/test/ttmlir/Dialect/TTNN/simple_ge.mlir new file mode 100644 index 000000000..e2d5e07cf --- /dev/null +++ b/test/ttmlir/Dialect/TTNN/simple_ge.mlir @@ -0,0 +1,15 @@ +// RUN: ttmlir-opt --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s +#any_device = #tt.operand_constraint +module attributes {tt.system_desc = #tt.system_desc<[{arch = , grid = <8x8>, l1_size = 1048576, num_dram_channels = 12, dram_channel_size = 1048576, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32}], [0], [], [<0, 0, 0, 0>]>} { + func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { + // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]] + // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]] + %0 = tensor.empty() : tensor<64x128xf32> + // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]] + // CHECK: %[[C:.*]] = "ttnn.ge"[[C:.*]] + %1 = "ttir.ge"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]] + // CHECK: "ttnn.close_device"[[C:.*]] + return %1 : tensor<64x128xf32> + } +}