Adding ttir.repeat op in MLIR (#1941)
This PR implements the TTIR repeat op and lowers it to the TTNN repeat op. Compiler and silicon tests for the repeat op are included.
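As an illustration, the lowering looks roughly like this (shapes and value names are illustrative, not taken from the tests):

```mlir
// TTIR: repeat a 2x3 tensor twice along each dimension.
%empty = tensor.empty() : tensor<4x6xf32>
%r = "ttir.repeat"(%input, %empty) <{repeat_dimensions = array<i64: 2, 2>}> : (tensor<2x3xf32>, tensor<4x6xf32>) -> tensor<4x6xf32>

// After the TTIRToTTNN conversion, roughly:
%r = "ttnn.repeat"(%input) <{repeat_dims = #ttnn.shape<2x2>}> : (tensor<2x3xf32>) -> tensor<4x6xf32>
```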

I used this opportunity to group the tests for data movement ops under a single folder.

I also added a conversion to emitC and an emitC test.

Closes #1916
sdjordjevicTT authored Jan 28, 2025
1 parent f627930 commit d739ed9
Showing 55 changed files with 440 additions and 144 deletions.
6 changes: 5 additions & 1 deletion include/ttmlir/Conversion/TTNNToEmitC/Utils.h
@@ -62,6 +62,9 @@ emitc::OpaqueAttr convertArrayAttrToSpan(Builder &builder, ArrayAttr attr);
//
emitc::OpaqueAttr createStdNullopt(Builder &builder);

// Helper enum to differentiate between ttnn::Shape and ttnn::SimpleShape
enum class ShapeType { SimpleShape = 0, Shape = 1 };

// Create ttnn::Shape and return emitc::ExpressionOp
//
// ttnn::Shape has a couple of constructors, but they are explicit and require
@@ -75,7 +78,8 @@ emitc::OpaqueAttr createStdNullopt(Builder &builder);
//
emitc::ExpressionOp createShapeOp(ConversionPatternRewriter &rewriter,
ttnn::ShapeAttr shapeAttr,
-Block *containingBlock, Location loc);
+Block *containingBlock, Location loc,
+ShapeType shapeType = ShapeType::SimpleShape);

// Create ttnn::MemoryConfig and return emitc::CallOpaqueOp
//
40 changes: 39 additions & 1 deletion include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -867,6 +867,41 @@ def TTIR_ConcatOp : TTIR_DPSOp<"concat"> {
let hasVerifier = 1;
}

def TTIR_RepeatOp : TTIR_DPSOp<"repeat"> {
let summary = "Repeat operation.";
let description = [{
The `repeat` operation creates a new tensor by replicating the input tensor's elements
along specified dimensions. The number of repetitions for each dimension is defined by
the `repeat_dimensions` attribute, whose length must match the rank of the input tensor.

Parameters:
- `input`: The input tensor.
- `repeat_dimensions`: Specifies the number of times to repeat this tensor along each dimension.

### Example IR Usage:
```mlir
// Input tensor of shape (2, 3)
%input = ... : tensor<2x3xf32>

// Repeat each dimension twice
%empty = tensor.empty() : tensor<4x6xf32>
%repeated = "ttir.repeat"(%input, %empty) {repeat_dimensions = array<i64: 2, 2>} : (tensor<2x3xf32>, tensor<4x6xf32>) -> tensor<4x6xf32>
```
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
DenseI64ArrayAttr:$repeat_dimensions);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

let hasVerifier = 1;
}

def TTIR_RepeatInterleaveOp : TTIR_DPSOp<"repeat_interleave"> {
let summary = "Repeat interleave op.";
let description = [{
@@ -942,11 +977,14 @@ def TTIR_BroadcastOp : TTIR_DPSOp<"broadcast"> {
// %arg0: tensor<1x1x32xf32>
%0 = tensor.empty() : tensor<1x16x32xf32>
%1 = "ttir.broadcast"(%arg0, %0) <{broadcast_dimensions = array<i32: 1, 16, 1>}> : (tensor<1x1x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>

Note: Currently, when generating a TTNN executable, the broadcast and repeat operations share the same semantics due to the lack of tensor view support in TTNN.
As a result, the broadcast operation is lowered to a repeat operation in the TTNN compilation pipeline.
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
-DenseI32ArrayAttr:$broadcast_dimensions);
+DenseI64ArrayAttr:$broadcast_dimensions);

let results = (outs AnyRankedTensor:$result);

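The note added to TTIR_BroadcastOp above means a broadcast is ultimately emitted as a repeat; a sketch of that lowering (shapes illustrative):

```mlir
// TTIR broadcast of a 1x1x32 tensor to 1x16x32...
%0 = tensor.empty() : tensor<1x16x32xf32>
%1 = "ttir.broadcast"(%arg0, %0) <{broadcast_dimensions = array<i64: 1, 16, 1>}> : (tensor<1x1x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>

// ...becomes a ttnn.repeat with the same per-dimension factors:
%1 = "ttnn.repeat"(%arg0) <{repeat_dims = #ttnn.shape<1x16x1>}> : (tensor<1x1x32xf32>) -> tensor<1x16x32xf32>
```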
8 changes: 6 additions & 2 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -826,11 +826,15 @@ def TTNN_ReshapeOp : TTNN_Op<"reshape"> {
def TTNN_RepeatOp : TTNN_Op<"repeat"> {
let summary = "Repeat op.";
let description = [{
-Repeat the input tensor according to number of times specified in repeat_dimensions.
+Returns a new tensor filled with repetitions of the input tensor, according to the number of times specified in repeat_dims.
+
+Parameters:
+- `input_tensor` (ttnn.Tensor): the input tensor.
+- `repeat_dims` (ttnn.Shape): the number of repetitions for each dimension.
}];

let arguments = (ins AnyRankedTensor:$input,
-I32ArrayAttr:$shape);
+TTNN_ShapeAttr:$repeat_dims);

let results = (outs AnyRankedTensor:$result);

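A minimal usage sketch with the new `repeat_dims` attribute (shape and value names hypothetical):

```mlir
// Repeat a 2x1 tensor four times along the last dimension.
%out = "ttnn.repeat"(%in) <{repeat_dims = #ttnn.shape<1x4>}> : (tensor<2x1xf32>) -> tensor<2x4xf32>
```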
2 changes: 1 addition & 1 deletion include/ttmlir/Target/TTNN/program.fbs
@@ -239,7 +239,7 @@ table ReshapeOp {
table RepeatOp {
in: tt.target.TensorRef;
out: tt.target.TensorRef;
-shape: [uint32];
+repeat_dims: [int64];
}

table SliceOp {
8 changes: 4 additions & 4 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -706,8 +706,8 @@ class StableHLOToTTIRBroadcastInDimOpConversionPattern
::llvm::ArrayRef<int64_t> inputShape = inputType.getShape();
::llvm::ArrayRef<int64_t> outputShape = outputType.getShape();

-SmallVector<int32_t> broadcastShape =
-ttmlir::utils::getBroadcastDimensions<int32_t>(inputShape,
+SmallVector<int64_t> broadcastShape =
+ttmlir::utils::getBroadcastDimensions<int64_t>(inputShape,
outputShape);

rewriter.replaceOpWithNewOp<mlir::tt::ttir::BroadcastOp>(
@@ -751,8 +751,8 @@ class StableHLOToTTIRBroadcastInDimOpConversionPattern
::llvm::ArrayRef<int64_t> inputShape = unsqueezeShape;
::llvm::ArrayRef<int64_t> outputShape = outputType.getShape();

-SmallVector<int32_t> broadcastShape =
-ttmlir::utils::getBroadcastDimensions<int32_t>(inputShape,
+SmallVector<int64_t> broadcastShape =
+ttmlir::utils::getBroadcastDimensions<int64_t>(inputShape,
outputShape);

rewriter.replaceOpWithNewOp<mlir::tt::ttir::BroadcastOp>(
@@ -18,6 +18,7 @@
#include "mlir/Transforms/DialectConversion.h"

#include <algorithm>
#include <cstdint>

using namespace mlir;
using namespace mlir::tt;
@@ -1311,8 +1312,8 @@ struct ArangeForceLastDimensionPattern
auto inputShape =
mlir::cast<mlir::RankedTensorType>(output.getType()).getShape();

-SmallVector<int32_t> broadcastShape =
-ttmlir::utils::getBroadcastDimensions<int32_t>(inputShape,
+SmallVector<int64_t> broadcastShape =
+ttmlir::utils::getBroadcastDimensions<int64_t>(inputShape,
outputShape);

output = rewriter.create<ttir::BroadcastOp>(
24 changes: 22 additions & 2 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -686,11 +686,30 @@ class BroadcastOpConversionPattern
matchAndRewrite(ttir::BroadcastOp op, ttir::BroadcastOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

-auto shapeAttr = adaptor.getBroadcastDimensionsAttr();
+ttnn::ShapeAttr shapeAttr = ttnn::ShapeAttr::get(
+rewriter.getContext(), op.getBroadcastDimensions());

rewriter.replaceOpWithNewOp<ttnn::RepeatOp>(
op, this->getTypeConverter()->convertType(op.getType()),
adaptor.getInput(), shapeAttr);

return success();
}
};

class RepeatOpConversionPattern : public OpConversionPattern<ttir::RepeatOp> {
using OpConversionPattern<ttir::RepeatOp>::OpConversionPattern;

public:
LogicalResult
matchAndRewrite(ttir::RepeatOp op, ttir::RepeatOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ttnn::ShapeAttr repeatDimensionsAttr =
ttnn::ShapeAttr::get(rewriter.getContext(), op.getRepeatDimensions());

rewriter.replaceOpWithNewOp<ttnn::RepeatOp>(
op, this->getTypeConverter()->convertType(op.getType()),
-adaptor.getInput(), rewriter.getI32ArrayAttr(shapeAttr));
+adaptor.getInput(), repeatDimensionsAttr);

return success();
}
@@ -1373,6 +1392,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
BroadcastOpConversionPattern,
EmbeddingOpConversionPattern,
EmbeddingBackwardOpConversionPattern,
RepeatOpConversionPattern,
RepeatInterleaveOpConversionPattern,
SoftmaxOpConversionPattern,
TransposeOpConversionPattern,
45 changes: 43 additions & 2 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -514,6 +514,48 @@ class ConcatOpConversionPattern
}
};

// Repeat op conversion pattern
//
class RepeatOpConversionPattern
: public TTNNToEmitCBaseOpConversionPattern<ttnn::RepeatOp> {
public:
using TTNNToEmitCBaseOpConversionPattern<
ttnn::RepeatOp>::TTNNToEmitCBaseOpConversionPattern;

LogicalResult
matchAndRewrite(ttnn::RepeatOp repeatOp, ttnn::RepeatOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ttnn::ShapeAttr repeatDims = repeatOp.getRepeatDimsAttr();

// Create ttnn::Shape() call
//
emitc::ExpressionOp shapeExpressionOp = ttnn_to_emitc::utils::createShapeOp(
rewriter, repeatDims, repeatOp->getBlock(), repeatOp.getLoc(),
ttnn_to_emitc::utils::ShapeType::Shape);

// Create operands vector
//
llvm::SmallVector<Value, 2> operands{
adaptor.getOperands()[0], // input tensor
shapeExpressionOp->getResult(0)};

// Create ArrayAttr object holding attributes and pointers to operands
//
ArrayAttr arrayAttrs = rewriter.getArrayAttr({
rewriter.getIndexAttr(0), // input tensor
rewriter.getIndexAttr(1), // ttnn::Shape
ttnn_to_emitc::utils::createStdNullopt(
rewriter) // std::nullopt for memory config
});

rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
repeatOp, this->getTypeConverter()->convertType(repeatOp.getType()),
this->convertOpName(repeatOp), arrayAttrs, nullptr, operands);

return success();
}
};

// GetDeviceOp conversion pattern
//
class GetDeviceOpConversionPattern
@@ -1092,8 +1134,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
// Tensor manipulation ops
//
patterns.add<TransposeOpConversionPattern, ConcatOpConversionPattern,
-ReshapeOpConversionPattern,
-DefaultOpConversionPattern<ttnn::RepeatOp>,
+ReshapeOpConversionPattern, RepeatOpConversionPattern,
DefaultOpConversionPattern<ttnn::RepeatInterleaveOp>,
DefaultOpConversionPattern<ttnn::SliceOp>,
DefaultOpConversionPattern<ttnn::PermuteOp>>(typeConverter, ctx);
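With this pattern, the converted IR takes roughly the following form; a sketch, assuming the usual `!emitc.opaque` type mapping, with hypothetical value names:

```mlir
// %shape is produced by the createShapeOp helper (an emitc.expression
// yielding a ttnn::Shape); the repeat itself becomes one opaque call whose
// `args` pick the two operands and append std::nullopt for the memory config.
%out = emitc.call_opaque "ttnn::repeat"(%in, %shape) {args = [0 : index, 1 : index, #emitc.opaque<"std::nullopt">]} : (!emitc.opaque<"ttnn::Tensor">, !emitc.opaque<"ttnn::Shape">) -> !emitc.opaque<"ttnn::Tensor">
// The emitted C++ is then along the lines of:
//   ttnn::Tensor out = ttnn::repeat(in, shape, std::nullopt);
```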
13 changes: 8 additions & 5 deletions lib/Conversion/TTNNToEmitC/Utils.cpp
@@ -8,6 +8,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include <llvm/ADT/STLExtras.h>
#include <llvm/ADT/StringRef.h>
#include <llvm/Support/raw_ostream.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/BuiltinAttributes.h>
@@ -194,13 +195,15 @@ emitc::OpaqueAttr createStdNullopt(Builder &builder) {

emitc::ExpressionOp createShapeOp(ConversionPatternRewriter &rewriter,
ttnn::ShapeAttr shapeAttr,
-Block *containingBlock, Location loc) {
+Block *containingBlock, Location loc,
+ShapeType shapeType) {
+llvm::StringRef shapeTypeStr =
+shapeType == ShapeType::SimpleShape ? "ttnn::SimpleShape" : "ttnn::Shape";
// Create ExpressionOp to hold multiple nested op calls, but will bundle them
// together into a single SSA value
//
emitc::ExpressionOp shapeExpressionOp = rewriter.create<emitc::ExpressionOp>(
-loc, emitc::OpaqueType::get(rewriter.getContext(), "ttnn::SimpleShape"),
-false);
+loc, emitc::OpaqueType::get(rewriter.getContext(), shapeTypeStr), false);

// Add a block to the ExpressionOp, save current insertion point, and set
// insertion point to newly added block
@@ -222,8 +225,8 @@ emitc::ExpressionOp createShapeOp(ConversionPatternRewriter &rewriter,
// Create a ttnn::SimpleShape object
//
emitc::CallOpaqueOp ttnnShapeOp = rewriter.create<emitc::CallOpaqueOp>(
-loc, emitc::OpaqueType::get(rewriter.getContext(), "ttnn::SimpleShape"),
-rewriter.getStringAttr("ttnn::SimpleShape"), nullptr, nullptr,
+loc, emitc::OpaqueType::get(rewriter.getContext(), shapeTypeStr),
+rewriter.getStringAttr(shapeTypeStr), nullptr, nullptr,
metalShapeOp->getResults());
rewriter.create<emitc::YieldOp>(loc, ttnnShapeOp->getResult(0));

50 changes: 50 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -1614,6 +1614,56 @@ ::mlir::LogicalResult mlir::tt::ttir::AllocOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// RepeatOp
//===----------------------------------------------------------------------===//

// RepeatOp verification.
::mlir::LogicalResult mlir::tt::ttir::RepeatOp::verify() {
::mlir::RankedTensorType inputType = getInput().getType();
::mlir::RankedTensorType outputType = getOutput().getType();
llvm::ArrayRef<int64_t> repeatDimensions = getRepeatDimensions();

// The input tensor and the repeat dimensions argument must have the same rank.
if (inputType.getRank() != static_cast<int64_t>(repeatDimensions.size())) {
return emitOpError() << "Input tensor rank " << inputType.getRank()
<< " doesn't match the number of repeat dimensions "
<< repeatDimensions.size() << ".";
}

// Input and output tensors must have the same rank.
if (inputType.getRank() != outputType.getRank()) {
return emitOpError() << "Input tensor rank " << inputType.getRank()
<< " doesn't match the output tensor rank "
<< outputType.getRank() << ".";
}

// Verify output shape based on input shape and repeat dimension argument.
llvm::ArrayRef<int64_t> inputShape = inputType.getShape();
llvm::ArrayRef<int64_t> outputShape = outputType.getShape();

for (size_t i = 0; i < inputShape.size(); i++) {
// Verify that the repeat dimension is greater than 0.
if (repeatDimensions[i] <= 0) {
return emitOpError() << "Repeat dimension at index " << i
<< " must be greater than 0.";
}

int64_t expectedDimValue = inputShape[i] * repeatDimensions[i];
if (expectedDimValue != outputShape[i]) {
return emitOpError() << "Input tensor shape ("
<< ttmlir::utils::join(inputShape, ",")
<< ") at index " << i
<< " does not repeat to output ("
<< ttmlir::utils::join(outputShape, ",")
<< ") using repeat value " << repeatDimensions[i]
<< ".";
}
}

return success();
}

//===----------------------------------------------------------------------===//
// RepeatInterleaveOp
//===----------------------------------------------------------------------===//
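The rule this verifier enforces is `output[i] == input[i] * repeat_dimensions[i]`, with every repeat value strictly positive; for example (shapes hypothetical):

```mlir
// Accepted: 2*2 == 4 and 3*2 == 6.
%ok = "ttir.repeat"(%a, %e0) <{repeat_dimensions = array<i64: 2, 2>}> : (tensor<2x3xf32>, tensor<4x6xf32>) -> tensor<4x6xf32>

// Rejected: 3*2 != 5, so verification fails with
// "does not repeat to output ... using repeat value 2".
%bad = "ttir.repeat"(%a, %e1) <{repeat_dimensions = array<i64: 2, 2>}> : (tensor<2x3xf32>, tensor<4x5xf32>) -> tensor<4x5xf32>
```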
35 changes: 29 additions & 6 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -10,6 +10,7 @@
#include "ttmlir/Dialect/TTNN/Utils/Utils.h"
#include "ttmlir/Utils.h"

#include <cstdint>
#include <numeric>
#include <optional>

@@ -412,19 +413,41 @@ ::mlir::LogicalResult mlir::tt::ttnn::ConcatOp::verify() {
::mlir::LogicalResult mlir::tt::ttnn::RepeatOp::verify() {
::mlir::RankedTensorType inputType = getInput().getType();
::mlir::RankedTensorType outputType = getResult().getType();
+llvm::ArrayRef<int64_t> repeatDims = getRepeatDims().getShape();

-auto shape = getShape();
// Verify that the input tensor and the repeat_dims argument have the same rank.
if (inputType.getRank() != static_cast<int64_t>(repeatDims.size())) {
return emitOpError() << "Input tensor rank " << inputType.getRank()
<< " doesn't match the number of repeat dimensions "
<< repeatDims.size() << ".";
}

// Verify that the input and output tensors have the same rank.
if (inputType.getRank() != outputType.getRank()) {
return emitOpError() << "Input tensor rank " << inputType.getRank()
<< " doesn't match the output tensor rank "
<< outputType.getRank() << ".";
}

// Verify expected output shape.
auto inputShape = inputType.getShape();
auto outputShape = outputType.getShape();

-for (size_t i = 0; i < shape.size(); i++) {
-uint32_t dimValue = mlir::cast<IntegerAttr>(shape[i]).getInt();
+for (size_t i = 0; i < repeatDims.size(); i++) {
// Verify that the repeat dimension is greater than 0.
if (repeatDims[i] <= 0) {
return emitOpError() << "Repeat dimension at index " << i
<< " must be greater than 0.";
}

int64_t dimValue = repeatDims[i];
if (inputShape[i] * dimValue != outputShape[i]) {
return emitOpError() << "Input tensor shape ("
-<< ttmlir::utils::join(inputShape, ",") << ") index "
-<< i << " does not repeat to output ("
+<< ttmlir::utils::join(inputShape, ",")
+<< ") at index " << i
+<< " does not repeat to output ("
<< ttmlir::utils::join(outputShape, ",")
-<< ") using repeat value " << dimValue;
+<< ") using repeat value " << dimValue << ".";
}
}
