diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
index 25e8e799e..8d41cced6 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -298,7 +298,7 @@ def TTIR_MatmulOp : TTIR_DPSOp<"matmul"> {
 
 def AnyRankedTensorOrMemRef: AnyTypeOf<[AnyRankedTensor, AnyNon0RankedMemRef]>;
 
-def TTIR_KernelOp : TTIR_Op<"kernel", [DestinationStyleOpInterface, AttrSizedOperandSegments]> {
+def TTIR_KernelOp : TTIR_DPSOp<"kernel", [AttrSizedOperandSegments]> {
   let summary = "Kernel call.";
   let description = [{
     A generic kernel call operation. This operation is used to pattern match by some consuming backend.
@@ -307,12 +307,9 @@ def TTIR_KernelOp : TTIR_Op<"kernel", [DestinationStyleOpe
   let arguments = (ins FlatSymbolRefAttr:$op,
                        FlatSymbolRefAttr:$kind,
                        Variadic<AnyRankedTensorOrMemRef>:$inputs,
-                       Variadic<AnyRankedTensorOrMemRef>:$outputs);
+                       Variadic<AnyRankedTensorOrMemRef>:$outputs,
+                       TT_OperandConstraintArrayAttr:$operand_constraints);
   let results = (outs Variadic<AnyRankedTensorOrMemRef>:$results);
-
-  let extraClassDeclaration = [{
-    MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
-  }];
 }
 
 def TTIR_YieldOp : TTIR_Op<"yield", [Pure, ReturnLike, Terminator]> {
diff --git a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td
index 8c1391af6..a42075c4a 100644
--- a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td
+++ b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td
@@ -34,6 +34,13 @@ def TTIRLayout: Pass<"ttir-layout", "::mlir::ModuleOp"> {
   let description = [{
     Transition between different tensor layouts.
   }];
+
+  let options = [
+    Option<"initMemorySpace", "init-memory-space",
+           "::mlir::tt::MemorySpace",
+           /*default=*/"::mlir::tt::MemorySpace::System",
+           "Set the initial memory space for tensors to start in">,
+  ];
 }
 
 def TTIRAllocate: Pass<"ttir-allocate", "::mlir::ModuleOp"> {
diff --git a/lib/Dialect/TTIR/Transforms/Passes.cpp b/lib/Dialect/TTIR/Transforms/Passes.cpp
index eb4ab8c1c..8f37f23f5 100644
--- a/lib/Dialect/TTIR/Transforms/Passes.cpp
+++ b/lib/Dialect/TTIR/Transforms/Passes.cpp
@@ -98,7 +98,7 @@ class TTIRNamedToKernelRewriter : public OpRewritePattern<TTIROpTy> {
 
     auto kernel = rewriter.create<KernelOp>(
        op.getLoc(), op.getResultTypes(), kernelName, kernelKind,
-        op.getInputs(), op.getOutputs());
+        op.getInputs(), op.getOutputs(), op.getOperandConstraints());
 
     rewriter.replaceOp(op, kernel);
 
@@ -343,14 +343,15 @@ inline MemorySpace uppermostMemorySpace(OperandConstraint operandConstraint) {
 
 class TTIRLayoutTensorTypeConverter : public TypeConverter {
 public:
-  TTIRLayoutTensorTypeConverter(MLIRContext *ctx) {
+  TTIRLayoutTensorTypeConverter(MLIRContext *ctx, MemorySpace initMemorySpace) {
     addConversion([](Type type) { return type; });
-    addConversion([ctx](RankedTensorType type) -> Type {
+    addConversion([ctx, initMemorySpace](RankedTensorType type) -> Type {
      auto layout = type.getEncoding();
      if (layout) {
        return type;
      }
-      auto newLayout = LayoutAttr::get(ctx, type);
+      // Default to initMemorySpace, the optimizer might decide otherwise
+      auto newLayout = LayoutAttr::get(ctx, type, initMemorySpace);
      return RankedTensorType::get(type.getShape(), type.getElementType(),
                                   newLayout);
    });
@@ -415,13 +416,12 @@ class TTIRLayoutTensorTypeRewriter : public RewritePattern {
   const TypeConverter *converter;
 };
 
-static std::optional<Value>
-createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input,
-                 OperandConstraint operandConstraint) {
+static std::optional<Value> createToLayoutOp(PatternRewriter &rewriter,
+                                             Location loc, Value input,
+                                             MemorySpace desiredMemorySpace) {
   auto ty = mlir::cast<RankedTensorType>(input.getType());
   auto currLayout = mlir::cast<LayoutAttr>(ty.getEncoding());
   auto currMemorySpace = currLayout.getMemorySpace();
-  auto desiredMemorySpace = uppermostMemorySpace(operandConstraint);
   if (currMemorySpace == desiredMemorySpace) {
     return std::nullopt;
   }
@@ -440,27 +440,38 @@ createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input,
       ->getResult(0);
 }
 
-template <typename TTIROpTy>
-class TTIRLayoutOperandsRewriter : public OpRewritePattern<TTIROpTy> {
+static std::optional<Value>
+createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input,
+                 OperandConstraint operandConstraint) {
+  auto desiredMemorySpace = uppermostMemorySpace(operandConstraint);
+  return createToLayoutOp(rewriter, loc, input, desiredMemorySpace);
+}
+
+class TTIRLayoutDPSOperandsRewriter
+    : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
 public:
-  using OpRewritePattern<TTIROpTy>::OpRewritePattern;
+  using OpInterfaceRewritePattern<
+      DestinationStyleOpInterface>::OpInterfaceRewritePattern;
 
-  LogicalResult matchAndRewrite(TTIROpTy op,
+  LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
                                 PatternRewriter &rewriter) const final {
+    if (mlir::isa<ToLayoutOp>(op.getOperation())) {
+      // Skip the ToLayoutOp itself
+      return failure();
+    }
+
     assert(op->template hasTrait());
-    auto dpsInterface = cast<DestinationStyleOpInterface>(op.getOperation());
     bool modified = false;
     for (auto &operand : op->getOpOperands()) {
-      bool isResult = dpsInterface.isDpsInit(&operand);
+      bool isResult = op.isDpsInit(&operand);
       auto encoding =
           mlir::cast<RankedTensorType>(operand.get().getType()).getEncoding();
-      if (not encoding) {
-        return failure(); // Hasn't been type converted yet
-      }
+      assert(encoding);
       auto operandConstraint =
           mlir::cast<OperandConstraintAttr>(
-              op.getOperandConstraints()[operand.getOperandNumber()])
+              mlir::cast<TTIROp>(op.getOperation())
+                  .getOperandConstraints()[operand.getOperandNumber()])
              .getValue();
       auto desiredLayout = createToLayoutOp(rewriter, op.getLoc(),
                                             operand.get(), operandConstraint);
@@ -495,14 +506,18 @@ class TTIRLayoutOperandsRewriter : public OpRewritePattern<TTIROpTy> {
 
 class TTIRLayoutFuncReturnRewriter
     : public OpRewritePattern<mlir::func::ReturnOp> {
 public:
-  using OpRewritePattern<mlir::func::ReturnOp>::OpRewritePattern;
+  TTIRLayoutFuncReturnRewriter(MLIRContext *ctx, MemorySpace initMemorySpace)
+      : OpRewritePattern<mlir::func::ReturnOp>(ctx),
+        initMemorySpace(initMemorySpace) {}
 
   LogicalResult matchAndRewrite(mlir::func::ReturnOp op,
                                 PatternRewriter &rewriter) const final {
     bool modified = false;
     for (auto &operand : op->getOpOperands()) {
+      // Leave the return values in initMemorySpace, optimizer might decide
+      // otherwise
       if (auto layout = createToLayoutOp(rewriter, op.getLoc(), operand.get(),
-                                         OperandConstraint::System);
+                                         initMemorySpace);
          layout) {
        rewriter.modifyOpInPlace(
            op, [&]() { op.setOperand(operand.getOperandNumber(), *layout); });
@@ -511,6 +526,9 @@ class TTIRLayoutFuncReturnRewriter
     }
     return modified ? success() : failure();
   }
+
+private:
+  MemorySpace initMemorySpace;
 };
 
 class TTIRLayout : public impl::TTIRLayoutBase<TTIRLayout> {
@@ -519,7 +537,8 @@ class TTIRLayout : public impl::TTIRLayoutBase<TTIRLayout> {
 
   void runOnOperation() final {
     {
-      TTIRLayoutTensorTypeConverter typeConverter(&getContext());
+      TTIRLayoutTensorTypeConverter typeConverter(&getContext(),
+                                                  initMemorySpace);
      RewritePatternSet patterns(&getContext());
      patterns.add<TTIRLayoutTensorTypeRewriter>(typeConverter, &getContext());
      FrozenRewritePatternSet patternSet(std::move(patterns));
@@ -530,18 +549,9 @@ class TTIRLayout : public impl::TTIRLayoutBase<TTIRLayout> {
     }
     {
      RewritePatternSet patterns(&getContext());
-      patterns.add<
-          TTIRLayoutOperandsRewriter,
-          TTIRLayoutOperandsRewriter,
-          TTIRLayoutOperandsRewriter,
-          TTIRLayoutOperandsRewriter,
-          TTIRLayoutOperandsRewriter,
-          TTIRLayoutOperandsRewriter, TTIRLayoutOperandsRewriter,
-          TTIRLayoutOperandsRewriter,
-          TTIRLayoutOperandsRewriter,
-          TTIRLayoutOperandsRewriter,
-          TTIRLayoutOperandsRewriter, TTIRLayoutFuncReturnRewriter>(
-          &getContext());
+      patterns.add<TTIRLayoutDPSOperandsRewriter>(&getContext());
+      patterns.add<TTIRLayoutFuncReturnRewriter>(&getContext(),
+                                                 initMemorySpace);
      FrozenRewritePatternSet patternSet(std::move(patterns));
      if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) {
        signalPassFailure();
diff --git a/lib/Dialect/TTMetal/Transforms/Passes.cpp b/lib/Dialect/TTMetal/Transforms/Passes.cpp
index 6fd7474de..937cc98d2 100644
--- a/lib/Dialect/TTMetal/Transforms/Passes.cpp
+++ b/lib/Dialect/TTMetal/Transforms/Passes.cpp
@@ -255,6 +255,8 @@ void createTTIRToTTMetalBackendPipeline(OpPassManager &pm) {
   pm.addPass(mlir::tt::ttir::createTTIRLoadSystemDesc());
   pm.addPass(mlir::tt::ttir::createTTIRImplicitDevice());
   pm.addPass(mlir::tt::ttir::createTTIRGeneric());
+  mlir::tt::ttir::TTIRLayoutOptions layoutOptions;
+  layoutOptions.initMemorySpace = mlir::tt::MemorySpace::DeviceL1;
   pm.addPass(mlir::tt::ttir::createTTIRLayout());
   pm.addPass(mlir::tt::ttir::createTTIRGenericRegionOperandsToMemref());
   pm.addPass(mlir::tt::ttir::createTTIRAllocate());
diff --git a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
index 2aec4092f..9a84025e6 100644
--- a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
+++ b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
@@ -17,7 +17,9 @@ void createTTIRToTTNNBackendPipeline(
     OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
   pm.addPass(mlir::tt::ttir::createTTIRLoadSystemDesc());
   pm.addPass(mlir::tt::ttir::createTTIRImplicitDevice());
-  pm.addPass(mlir::tt::ttir::createTTIRLayout());
+  mlir::tt::ttir::TTIRLayoutOptions layoutOptions;
+  layoutOptions.initMemorySpace = mlir::tt::MemorySpace::System;
+  pm.addPass(mlir::tt::ttir::createTTIRLayout(layoutOptions));
 
   if (options.gridSetPassEnabled) {
     ttir::TTIRGridSetOptions gridSetOptions;
diff --git a/test/ttmlir/Dialect/TTIR/test_allocate.mlir b/test/ttmlir/Dialect/TTIR/test_allocate.mlir
index 0968a67c7..3708d115f 100644
--- a/test/ttmlir/Dialect/TTIR/test_allocate.mlir
+++ b/test/ttmlir/Dialect/TTIR/test_allocate.mlir
@@ -1,12 +1,13 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-implicit-device --ttir-layout --ttir-allocate %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-load-system-desc --ttir-implicit-device --ttir-allocate %s | FileCheck %s
 #any_device = #tt.operand_constraint<any_device>
-module attributes {} {
-  func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
+#l1_ = #tt.memory_space<l1>
+#layout = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>>
+module attributes {tt.device = #tt.device<#tt.grid<8x8, (d0, d1) -> (0, d0, d1)>, [0]>, tt.system_desc = #tt.system_desc<[{arch = <wormhole_b0>, grid = 8x8, l1_size = 1499136, num_dram_channels = 12, dram_channel_size = 1073741824, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32}], [0], [], [<0, 0, 0, 0>]>} {
+  func.func @forward(%arg0: tensor<64x128xf32, #layout>, %arg1: tensor<64x128xf32, #layout>) -> tensor<64x128xf32, #layout> {
     // CHECK: %[[C:.*]] = "ttir.alloc"[[C:.*]]
     // CHECK-NOT: %[[C:.*]] = tensor.empty() : tensor<64x128xf32>
-    %0 = tensor.empty() : tensor<64x128xf32>
-    %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
-    // CHECK: "ttir.dealloc"[[C:.*]]
-    return %1 : tensor<64x128xf32>
+    %0 = tensor.empty() : tensor<64x128xf32, #layout>
+    %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32, #layout>, tensor<64x128xf32, #layout>, tensor<64x128xf32, #layout>) -> tensor<64x128xf32, #layout>
+    return %1 : tensor<64x128xf32, #layout>
   }
 }
diff --git a/test/ttmlir/Dialect/TTIR/test_grid_set.mlir b/test/ttmlir/Dialect/TTIR/test_grid_set.mlir
index 9c867852e..bf6eae61e 100644
--- a/test/ttmlir/Dialect/TTIR/test_grid_set.mlir
+++ b/test/ttmlir/Dialect/TTIR/test_grid_set.mlir
@@ -3,10 +3,8 @@
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
     %0 = tensor.empty() : tensor<64x128xf32>
-    // CHECK: #layout2 = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x16xf32, #l1_>>
-    // CHECK: %[[C:.*]] = "ttir.to_layout"[[C:.*]]
-    // CHECK: %[[C:.*]] = "ttir.to_layout"[[C:.*]]
-    // CHECK: %[[C:.*]] = "ttir.multiply"[[C:.*]] -> tensor<64x128xf32, #layout2>
+    // CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x16xf32, #l1_>>
+    // CHECK: %[[C:.*]] = "ttir.multiply"[[C:.*]] -> tensor<64x128xf32, #[[LAYOUT_1]]>
     %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     return %1 : tensor<64x128xf32>
   }
diff --git a/test/ttmlir/Dialect/TTIR/test_layout.mlir b/test/ttmlir/Dialect/TTIR/test_layout.mlir
index e59fea576..5a5f11426 100644
--- a/test/ttmlir/Dialect/TTIR/test_layout.mlir
+++ b/test/ttmlir/Dialect/TTIR/test_layout.mlir
@@ -2,10 +2,8 @@
 #any_device = #tt.operand_constraint<any_device>
 module attributes {} {
   func.func @forward(%arg0: tensor<8x64x128xf32>, %arg1: tensor<8x64x128xf32>) -> tensor<8x64x128xf32> {
+    // CHECK: %[[C:.*]] = tensor.empty() : tensor<8x64x128xf32, #layout>
     %0 = tensor.empty() : tensor<8x64x128xf32>
-    // CHECK: %[[C:.*]] = "ttir.to_layout"[[C:.*]]
-    // CHECK: %[[C:.*]] = "ttir.to_layout"[[C:.*]]
-    // CHECK: %[[C:.*]] = "ttir.multiply"[[C:.*]]
     %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<8x64x128xf32>, tensor<8x64x128xf32>, tensor<8x64x128xf32>) -> tensor<8x64x128xf32>
     return %1 : tensor<8x64x128xf32>
   }
diff --git a/test/ttmlir/Dialect/TTMetal/simple_eltwise.mlir b/test/ttmlir/Dialect/TTMetal/simple_eltwise.mlir
index cd3bcea00..689be9072 100644
--- a/test/ttmlir/Dialect/TTMetal/simple_eltwise.mlir
+++ b/test/ttmlir/Dialect/TTMetal/simple_eltwise.mlir
@@ -3,22 +3,16 @@
 
 func.func @multiply(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
   // CHECK: %[[C:.*]] = "ttmetal.alloc"[[C:.*]]
-  // CHECK: %[[C:.*]] = "ttmetal.host_write"[[C:.*]]
   %0 = tensor.empty() : tensor<64x128xf32>
   // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]]
   %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
-  // CHECK: "ttmetal.dealloc"[[C:.*]]
-  // CHECK: %[[C:.*]] = "ttmetal.host_read"[[C:.*]]
   return %1 : tensor<64x128xf32>
 }
 
 func.func @add(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
   // CHECK: %[[C:.*]] = "ttmetal.alloc"[[C:.*]]
-  // CHECK: %[[C:.*]] = "ttmetal.host_write"[[C:.*]]
   %0 = tensor.empty() : tensor<64x128xf32>
   // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]]
   %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
-  // CHECK: "ttmetal.dealloc"[[C:.*]]
-  // CHECK: %[[C:.*]] = "ttmetal.host_read"[[C:.*]]
   return %1 : tensor<64x128xf32>
 }
diff --git a/test/ttmlir/Dialect/TTNN/multiple_add_with_loc.mlir b/test/ttmlir/Dialect/TTNN/multiple_add_with_loc.mlir
index ce733171c..5ba74e6f6 100644
--- a/test/ttmlir/Dialect/TTNN/multiple_add_with_loc.mlir
+++ b/test/ttmlir/Dialect/TTNN/multiple_add_with_loc.mlir
@@ -3,18 +3,17 @@
 #loc = loc("test_ops.py:17_0_0":0:0)
 module @pybuda_graph attributes {} {
   func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> (tensor<1x32x32xf32>, tensor<1x32x32xf32>) {
-    // CHECK: #layout1 = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #system>>
-    // CHECK: #layout2 = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #l1_>>
+    // CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #l1_>>
     %0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5)
-    // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #layout2>
+    // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
     %1 = "ttir.add"(%arg1, %arg2, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5)
     %2 = tensor.empty() : tensor<1x32x32xf32> loc(#loc6)
-    // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #layout2>
+    // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
     %3 = "ttir.add"(%1, %arg0, %2) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc6)
     %4 = tensor.empty() : tensor<1x32x32xf32> loc(#loc7)
-    // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #layout2>
+    // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
     %5 = "ttir.add"(%arg2, %arg1, %4) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc7)
-    // CHECK: return %20, %22 : tensor<1x32x32xf32, #layout1>, tensor<1x32x32xf32, #layout1>
+    // CHECK: return %[[R0:.*]], %[[R1:.*]] : tensor<1x32x32xf32, #layout1>, tensor<1x32x32xf32, #layout1>
     return %3, %5 : tensor<1x32x32xf32>, tensor<1x32x32xf32> loc(#loc4)
   } loc(#loc)
 } loc(#loc)
diff --git a/test/ttmlir/Dialect/TTNN/multiple_add_with_loc_grid_override.mlir b/test/ttmlir/Dialect/TTNN/multiple_add_with_loc_grid_override.mlir
index 8e068bcdb..ae356c481 100644
--- a/test/ttmlir/Dialect/TTNN/multiple_add_with_loc_grid_override.mlir
+++ b/test/ttmlir/Dialect/TTNN/multiple_add_with_loc_grid_override.mlir
@@ -3,19 +3,19 @@
 #loc = loc("test_ops.py:17_0_0":0:0)
 module @pybuda_graph attributes {} {
   func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> (tensor<1x32x32xf32>, tensor<1x32x32xf32>) {
-    // CHECK: #layout1 = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #system>>
-    // CHECK: #layout2 = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <4x4>, memref<8x8xf32, #l1_>>
-    // CHECK: #layout3 = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #l1_>>
+    // CHECK: #[[LAYOUT_0:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #system>>
+    // CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <4x4>, memref<8x8xf32, #l1_>>
+    // CHECK: #[[LAYOUT_2:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #l1_>>
     %0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5)
-    // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #layout2>
+    // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
     %1 = "ttir.add"(%arg1, %arg2, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5)
     %2 = tensor.empty() : tensor<1x32x32xf32> loc(#loc6)
-    // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #layout2>
+    // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
     %3 = "ttir.add"(%1, %arg0, %2) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc6)
     %4 = tensor.empty() : tensor<1x32x32xf32> loc(#loc7)
-    // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #layout3>
+    // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_2]]>
     %5 = "ttir.add"(%arg2, %arg1, %4) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc7)
-    // CHECK: return %20, %22 : tensor<1x32x32xf32, #layout1>, tensor<1x32x32xf32, #layout1>
+    // CHECK: return %[[R0:.*]], %[[R1:.*]] : tensor<1x32x32xf32, #[[LAYOUT_0]]>, tensor<1x32x32xf32, #[[LAYOUT_0]]>
     return %3, %5 : tensor<1x32x32xf32>, tensor<1x32x32xf32> loc(#loc4)
   } loc(#loc)
 } loc(#loc)
diff --git a/test/ttmlir/Dialect/TTNN/simple_ge.mlir b/test/ttmlir/Dialect/TTNN/simple_ge.mlir
index 762345b5f..1b197bbe4 100644
--- a/test/ttmlir/Dialect/TTNN/simple_ge.mlir
+++ b/test/ttmlir/Dialect/TTNN/simple_ge.mlir
@@ -5,10 +5,8 @@ module attributes {} {
     // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
     %0 = tensor.empty() : tensor<64x128xf32>
-    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.ge"[[C:.*]]
     %1 = "ttir.ge"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
-    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
     // CHECK: "ttnn.close_device"[[C:.*]]
     return %1 : tensor<64x128xf32>
   }
diff --git a/test/ttmlir/Dialect/TTNN/simple_mean.mlir b/test/ttmlir/Dialect/TTNN/simple_mean.mlir
index fe2f586af..945c8bef1 100644
--- a/test/ttmlir/Dialect/TTNN/simple_mean.mlir
+++ b/test/ttmlir/Dialect/TTNN/simple_mean.mlir
@@ -5,10 +5,8 @@ module attributes {tt.system_desc = #tt.system_desc<[{arch = <wormhole_b0>, grid
     // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
     %0 = tensor.empty() : tensor<512x32xbf16>
-    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.mean"[[C:.*]]
     %1 = "ttir.mean"(%arg0, %0) <{dim_arg = [-1: i32], keep_dim = true, operand_constraints = [#any_device, #any_device]}> : (tensor<512x1024xbf16>, tensor<512x32xbf16>) -> tensor<512x32xbf16>
-    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
     // CHECK: "ttnn.close_device"[[C:.*]]
     return %1 : tensor<512x32xbf16>
   }
diff --git a/test/ttmlir/Dialect/TTNN/simple_multiply.mlir b/test/ttmlir/Dialect/TTNN/simple_multiply.mlir
index af9af72b0..a5d0cd25f 100644
--- a/test/ttmlir/Dialect/TTNN/simple_multiply.mlir
+++ b/test/ttmlir/Dialect/TTNN/simple_multiply.mlir
@@ -5,10 +5,8 @@ module attributes {} {
     // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
     %0 = tensor.empty() : tensor<64x128xf32>
-    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.multiply"[[C:.*]]
     %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
-    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
     // CHECK: "ttnn.close_device"[[C:.*]]
     return %1 : tensor<64x128xf32>
   }
diff --git a/test/ttmlir/Dialect/TTNN/simple_relu.mlir b/test/ttmlir/Dialect/TTNN/simple_relu.mlir
index 1545e921c..ab961f0a3 100644
--- a/test/ttmlir/Dialect/TTNN/simple_relu.mlir
+++ b/test/ttmlir/Dialect/TTNN/simple_relu.mlir
@@ -5,10 +5,8 @@ module attributes {} {
     // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
     %0 = tensor.empty() : tensor<64x128xf32>
-    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.relu"[[C:.*]]
     %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
-    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
     // CHECK: "ttnn.close_device"[[C:.*]]
     return %1 : tensor<64x128xf32>
   }
diff --git a/test/ttmlir/Dialect/TTNN/simple_subtract.mlir b/test/ttmlir/Dialect/TTNN/simple_subtract.mlir
index 18cabdab0..7d35fedf0 100644
--- a/test/ttmlir/Dialect/TTNN/simple_subtract.mlir
+++ b/test/ttmlir/Dialect/TTNN/simple_subtract.mlir
@@ -5,10 +5,8 @@ module attributes {} {
     // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
     %0 = tensor.empty() : tensor<64x128xf32>
-    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.subtract"[[C:.*]]
     %1 = "ttir.subtract"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
-    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
     // CHECK: "ttnn.close_device"[[C:.*]]
     return %1 : tensor<64x128xf32>
   }
diff --git a/test/ttmlir/Dialect/TTNN/simple_sum.mlir b/test/ttmlir/Dialect/TTNN/simple_sum.mlir
index fa7e51b2a..583908a8d 100644
--- a/test/ttmlir/Dialect/TTNN/simple_sum.mlir
+++ b/test/ttmlir/Dialect/TTNN/simple_sum.mlir
@@ -5,10 +5,8 @@ module attributes {} {
     // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
     %0 = tensor.empty() : tensor<512x32xbf16>
-    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.sum"[[C:.*]]
     %1 = "ttir.sum"(%arg0, %0) <{dim_arg = [-1: i32], keep_dim = true, operand_constraints = [#any_device, #any_device]}> : (tensor<512x1024xbf16>, tensor<512x32xbf16>) -> tensor<512x32xbf16>
-    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
     // CHECK: "ttnn.close_device"[[C:.*]]
     return %1 : tensor<512x32xbf16>
   }
diff --git a/test/ttmlir/Dialect/TTNN/softmax/simple_softmax.mlir b/test/ttmlir/Dialect/TTNN/softmax/simple_softmax.mlir
index 7c17037e7..510600185 100644
--- a/test/ttmlir/Dialect/TTNN/softmax/simple_softmax.mlir
+++ b/test/ttmlir/Dialect/TTNN/softmax/simple_softmax.mlir
@@ -5,17 +5,14 @@ module attributes {} {
     // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
     %0 = tensor.empty() : tensor<512x1024xbf16>
-    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.softmax"[[C:.*]]
     // Check for positive dimension attribute
     %1 = "ttir.softmax"(%arg0, %0) <{dimension = 1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<512x1024xbf16>, tensor<512x1024xbf16>) -> tensor<512x1024xbf16>
     // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
     %2 = tensor.empty() : tensor<512x1024xbf16>
-    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.softmax"[[C:.*]]
     // Check for negative dimension attribute
     %3 = "ttir.softmax"(%1, %2) <{dimension = -1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<512x1024xbf16>, tensor<512x1024xbf16>) -> tensor<512x1024xbf16>
-    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
     // CHECK: "ttnn.close_device"[[C:.*]]
     return %3 : tensor<512x1024xbf16>
   }
diff --git a/test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline.mlir b/test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline.mlir
index 30690844d..7e4d141b6 100644
--- a/test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline.mlir
+++ b/test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline.mlir
@@ -2,13 +2,12 @@
 #any_device = #tt.operand_constraint<any_device>
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
-    // CHECK: #layout2 = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x16xf32, #l1_>>
+    // CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x16xf32, #l1_>>
     // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
     %0 = tensor.empty() : tensor<64x128xf32>
-    // CHECK: %[[C:.*]] = "ttnn.multiply"[[C:.*]] -> tensor<64x128xf32, #layout2>
+    // CHECK: %[[C:.*]] = "ttnn.multiply"[[C:.*]] -> tensor<64x128xf32, #[[LAYOUT_1]]>
     %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
-    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
     // CHECK: "ttnn.close_device"[[C:.*]]
     return %1 : tensor<64x128xf32>
   }
diff --git a/test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline_custom_opt.mlir b/test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline_custom_opt.mlir
index 880093c20..6d68411b3 100644
--- a/test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline_custom_opt.mlir
+++ b/test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline_custom_opt.mlir
@@ -2,13 +2,12 @@
 #any_device = #tt.operand_constraint<any_device>
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
-    // CHECK: #layout1 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>>
+    // CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>>
     // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
     %0 = tensor.empty() : tensor<64x128xf32>
-    // CHECK: %[[C:.*]] = "ttnn.multiply"[[C:.*]] -> tensor<64x128xf32, #layout1>
+    // CHECK: %[[C:.*]] = "ttnn.multiply"[[C:.*]] -> tensor<64x128xf32, #[[LAYOUT_1:.*]]>
     %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
-    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
     // CHECK: "ttnn.close_device"[[C:.*]]
     return %1 : tensor<64x128xf32>
   }
diff --git a/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir b/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir
index cdf53ffac..ab4dfa5a2 100644
--- a/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir
+++ b/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir
@@ -8,10 +8,8 @@ func.func @subtract(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> ten
   // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
   // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
   %0 = tensor.empty() : tensor<64x128xf32>
-  // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
   // CHECK: %[[C:.*]] = "ttnn.subtract"[[C:.*]]
   %1 = "ttir.subtract"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
-  // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
   // CHECK: "ttnn.close_device"[[C:.*]]
   return %1 : tensor<64x128xf32>
 }
@@ -20,10 +18,8 @@ func.func @multiply(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> ten
   // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
   // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
   %0 = tensor.empty() : tensor<64x128xf32>
-  // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
   // CHECK: %[[C:.*]] = "ttnn.multiply"[[C:.*]]
   %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
-  // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
   // CHECK: "ttnn.close_device"[[C:.*]]
   return %1 : tensor<64x128xf32>
 }
@@ -32,10 +28,8 @@ func.func @relu(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
   // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
   // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
   %0 = tensor.empty() : tensor<64x128xf32>
-  // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
   // CHECK: %[[C:.*]] = "ttnn.relu"[[C:.*]]
   %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
-  // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
   // CHECK: "ttnn.close_device"[[C:.*]]
   return %1 : tensor<64x128xf32>
 }
@@ -44,10 +38,8 @@ func.func @ge(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64
   // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
   // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
   %0 = tensor.empty() : tensor<64x128xf32>
-  // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
   // CHECK: %[[C:.*]] = "ttnn.ge"[[C:.*]]
   %1 = "ttir.ge"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
-  // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
   // CHECK: "ttnn.close_device"[[C:.*]]
   return %1 : tensor<64x128xf32>
 }
diff --git a/test/ttmlir/Silicon/TTNN/simple_ge.mlir b/test/ttmlir/Silicon/TTNN/simple_ge.mlir
index 66081bbce..88ba8bc87 100644
--- a/test/ttmlir/Silicon/TTNN/simple_ge.mlir
+++ b/test/ttmlir/Silicon/TTNN/simple_ge.mlir
@@ -8,10 +8,8 @@ module attributes {} {
     // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
     %0 = tensor.empty() : tensor<64x128xf32>
-    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.ge"[[C:.*]]
     %1 = "ttir.ge"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
-    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
     // CHECK: "ttnn.close_device"[[C:.*]]
     return %1 : tensor<64x128xf32>
   }
diff --git a/test/ttmlir/Silicon/TTNN/simple_multiply.mlir b/test/ttmlir/Silicon/TTNN/simple_multiply.mlir
index ac0b95413..9ecd6b744 100644
--- a/test/ttmlir/Silicon/TTNN/simple_multiply.mlir
+++ b/test/ttmlir/Silicon/TTNN/simple_multiply.mlir
@@ -8,10 +8,8 @@ module attributes {} {
     // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
     %0 = tensor.empty() : tensor<64x128xf32>
-    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.multiply"[[C:.*]]
     %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
-    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
     // CHECK: "ttnn.close_device"[[C:.*]]
     return %1 : tensor<64x128xf32>
   }
diff --git a/test/ttmlir/Silicon/TTNN/simple_relu.mlir b/test/ttmlir/Silicon/TTNN/simple_relu.mlir
index 3d017ab6e..798ce6af4 100644
--- a/test/ttmlir/Silicon/TTNN/simple_relu.mlir
+++ b/test/ttmlir/Silicon/TTNN/simple_relu.mlir
@@ -8,10 +8,8 @@ module attributes {} {
     // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
     %0 = tensor.empty() : tensor<64x128xf32>
-    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.relu"[[C:.*]]
     %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
-    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
     // CHECK: "ttnn.close_device"[[C:.*]]
     return %1 : tensor<64x128xf32>
   }
diff --git a/test/ttmlir/Silicon/TTNN/simple_subtract.mlir b/test/ttmlir/Silicon/TTNN/simple_subtract.mlir
index b1c63ee06..d9895efe3 100644
--- a/test/ttmlir/Silicon/TTNN/simple_subtract.mlir
+++ b/test/ttmlir/Silicon/TTNN/simple_subtract.mlir
@@ -8,10 +8,8 @@ module attributes {} {
     // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
     %0 = tensor.empty() : tensor<64x128xf32>
-    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.subtract"[[C:.*]]
     %1 = "ttir.subtract"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
-    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
     // CHECK: "ttnn.close_device"[[C:.*]]
     return %1 : tensor<64x128xf32>
   }
diff --git a/test/ttmlir/Silicon/TTNN/simple_sum.mlir b/test/ttmlir/Silicon/TTNN/simple_sum.mlir
index efd346f8c..769ca9f88 100644
--- a/test/ttmlir/Silicon/TTNN/simple_sum.mlir
+++ b/test/ttmlir/Silicon/TTNN/simple_sum.mlir
@@ -8,10 +8,8 @@ module attributes {} {
     // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
     %0 = tensor.empty() : tensor<512x32xbf16>
-    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.sum"[[C:.*]]
     %1 = "ttir.sum"(%arg0, %0) <{dim_arg = [-1: i32], keep_dim = true, operand_constraints = [#any_device, #any_device]}> : (tensor<512x1024xbf16>, tensor<512x32xbf16>) -> tensor<512x32xbf16>
-    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
     // CHECK: "ttnn.close_device"[[C:.*]]
     return %1 : tensor<512x32xbf16>
   }
diff --git a/test/ttmlir/Silicon/TTNN/ttir_to_ttnn_pipeline_custom_opt.mlir b/test/ttmlir/Silicon/TTNN/ttir_to_ttnn_pipeline_custom_opt.mlir
deleted file mode 100644
index 89cbdcf18..000000000
--- a/test/ttmlir/Silicon/TTNN/ttir_to_ttnn_pipeline_custom_opt.mlir
+++ /dev/null
@@ -1,18 +0,0 @@
-// RUN: ttmlir-opt --ttir-load-system-desc="path=%system_desc_path%" --ttir-to-ttnn-backend-pipeline="enable-grid-set=false" %s > %t.mlir
-// RUN: FileCheck %s --input-file=%t.mlir
-// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
-
-#any_device = #tt.operand_constraint<any_device>
-module attributes {} {
-  func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
-    // CHECK: #layout1 = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>>
-    // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
-    // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
-    %0 = tensor.empty() : tensor<64x128xf32>
-    // CHECK: %[[C:.*]] = "ttnn.multiply"[[C:.*]] -> tensor<64x128xf32, #layout1>
-    %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 1, 1, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
-    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
-    // CHECK: "ttnn.close_device"[[C:.*]]
-    return %1 : tensor<64x128xf32>
-  }
-}
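Taken together, these changes let each backend pipeline choose the memory space tensors start in: the TTNN pipeline above passes MemorySpace::System, while the TTMetal pipeline constructs a DeviceL1 options struct. The same option can be set when assembling a custom pipeline. The snippet below is a minimal sketch only, based on the TTIRLayoutOptions/createTTIRLayout declarations generated from the Passes.td option added above; the buildCustomPipeline helper and the Passes.h include path are illustrative assumptions, not part of this change.

// Sketch: start tensors in device L1 for a hypothetical custom pipeline.
#include "mlir/Pass/PassManager.h"
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h" // assumed header location

void buildCustomPipeline(mlir::OpPassManager &pm) { // hypothetical helper
  // TTIRLayoutOptions and createTTIRLayout(options) come from the pass
  // option declared in Passes.td in this change.
  mlir::tt::ttir::TTIRLayoutOptions layoutOptions;
  layoutOptions.initMemorySpace = mlir::tt::MemorySpace::DeviceL1;
  pm.addPass(mlir::tt::ttir::createTTIRLayout(layoutOptions));
}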