
Modified parameter names for transpose.
vladimirjovanovicTT committed Aug 9, 2024
1 parent 286e354 commit 9294a7b
Showing 12 changed files with 34 additions and 34 deletions.
include/ttmlir/Dialect/TTIR/IR/TTIROps.td (2 changes: 1 addition & 1 deletion)
@@ -252,8 +252,8 @@ def TTIR_TransposeOp : TTIR_DPSOp<"transpose"> {

 let arguments = (ins AnyRankedTensor:$input,
                      AnyRankedTensor:$output,
+                     SI32Attr:$dimension0,
                      SI32Attr:$dimension1,
-                     SI32Attr:$dimension2,
                      TT_OperandConstraintArrayAttr:$operand_constraints);

 let results = (outs AnyRankedTensor:$result);
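Note: mlir-tblgen derives C++ getters from these ODS argument names, which is why the C++ changes below track the .td renames one-for-one. A rough sketch of the generated accessors after this commit (assumed from standard ODS naming conventions, not copied from the generated header):

    // Assumed tblgen-generated accessors on the transpose ops; the exact
    // signatures in the generated *Ops.h.inc may differ.
    int32_t getDimension0(); // previously getDimension1()
    int32_t getDimension1(); // previously getDimension2()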
include/ttmlir/Dialect/TTNN/IR/TTNNOps.td (4 changes: 2 additions & 2 deletions)
@@ -165,8 +165,8 @@ def TTNN_TransposeOp : TTNN_NamedDPSOp<"transpose"> {

 let arguments = (ins AnyRankedTensor:$input,
                      AnyRankedTensor:$output,
-                     SI32Attr:$dimension1,
-                     SI32Attr:$dimension2);
+                     SI32Attr:$dimension0,
+                     SI32Attr:$dimension1);

 let results = (outs AnyRankedTensor:$result);

include/ttmlir/Target/TTNN/program.fbs (2 changes: 1 addition & 1 deletion)
@@ -60,8 +60,8 @@ table SoftmaxOp {
 table TransposeOp {
   in: tt.target.TensorRef;
   out: tt.target.TensorRef;
+  dimension0: int32;
   dimension1: int32;
-  dimension2: int32;
 }

// ANCHOR: adding_an_op_matmul_fbs
lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp (4 changes: 2 additions & 2 deletions)
@@ -121,8 +121,8 @@ class TransposeOpConversionPattern
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<ttnn::TransposeOp>(
         op, this->getTypeConverter()->convertType(op.getType()),
-        adaptor.getInput(), adaptor.getOutput(), adaptor.getDimension1(),
-        adaptor.getDimension2());
+        adaptor.getInput(), adaptor.getOutput(), adaptor.getDimension0(),
+        adaptor.getDimension1());
     return success();
   }
 };
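The hunk above starts mid-signature; for context, here is a sketch of the whole pattern reconstructed from typical MLIR conversion boilerplate (every line outside the diff is an assumption, not copied from the source file):

    // Sketch: one-to-one lowering of ttir.transpose to ttnn.transpose that
    // forwards the renamed dimension attributes through the op adaptor.
    class TransposeOpConversionPattern
        : public OpConversionPattern<ttir::TransposeOp> {
    public:
      using OpConversionPattern<ttir::TransposeOp>::OpConversionPattern;

      LogicalResult
      matchAndRewrite(ttir::TransposeOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter) const override {
        rewriter.replaceOpWithNewOp<ttnn::TransposeOp>(
            op, this->getTypeConverter()->convertType(op.getType()),
            adaptor.getInput(), adaptor.getOutput(), adaptor.getDimension0(),
            adaptor.getDimension1());
        return success();
      }
    };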
lib/Dialect/TTIR/IR/TTIROps.cpp (16 changes: 8 additions & 8 deletions)
@@ -51,30 +51,30 @@ ::mlir::LogicalResult mlir::tt::ttir::TransposeOp::verify() {
   ::mlir::RankedTensorType outputType = getOutput().getType();
   auto inputShape = inputType.getShape();
   auto outputShape = outputType.getShape();
+  int32_t dim0 = getDimension0();
   int32_t dim1 = getDimension1();
-  int32_t dim2 = getDimension2();
   if (inputType.getRank() < 2) {
     return emitOpError("Input must be at least a 2D tensor");
   }
   if (inputType.getRank() != outputType.getRank()) {
     return emitOpError("Input must have the same rank as output");
   }
-  if (dim1 >= inputType.getRank() || dim1 < -inputType.getRank()) {
+  if (dim0 >= inputType.getRank() || dim0 < -inputType.getRank()) {
     return emitOpError(
         "Dimension 1 attribute must be within the bounds of the input tensor");
   }
-  if (dim2 >= inputType.getRank() || dim2 < -inputType.getRank()) {
+  if (dim1 >= inputType.getRank() || dim1 < -inputType.getRank()) {
     return emitOpError(
         "Dimension 2 attribute must be within the bounds of the input tensor");
   }
+  if (dim0 < 0) {
+    dim0 += inputType.getRank();
+  }
   if (dim1 < 0) {
     dim1 += inputType.getRank();
   }
-  if (dim2 < 0) {
-    dim2 += inputType.getRank();
-  }
-  if (outputShape[dim1] != inputShape[dim2] ||
-      outputShape[dim2] != inputShape[dim1]) {
+  if (outputShape[dim0] != inputShape[dim1] ||
+      outputShape[dim1] != inputShape[dim0]) {
     return emitOpError("Input-output transpose dimension mismatch.");
   }
   return success();
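The verifier's dimension handling is easier to follow outside diff form. A minimal standalone sketch of the same checks, using plain C++ containers in place of the MLIR types (the function name and signature are illustrative, not part of the source):

    #include <cstdint>
    #include <vector>

    // Mirrors TransposeOp::verify(): each dim must lie in [-rank, rank),
    // negative dims wrap around, and the two transposed extents must swap
    // between the input and output shapes.
    bool verifyTranspose(const std::vector<int64_t> &inputShape,
                         const std::vector<int64_t> &outputShape,
                         int32_t dim0, int32_t dim1) {
      const int32_t rank = static_cast<int32_t>(inputShape.size());
      if (rank < 2 || outputShape.size() != inputShape.size())
        return false;
      if (dim0 >= rank || dim0 < -rank || dim1 >= rank || dim1 < -rank)
        return false;
      if (dim0 < 0)
        dim0 += rank;
      if (dim1 < 0)
        dim1 += rank;
      return outputShape[dim0] == inputShape[dim1] &&
             outputShape[dim1] == inputShape[dim0];
    }

For a 64x128 input and 128x64 output, dim0 = 0 with dim1 = 1 passes, and so does dim0 = -1 with dim1 = -2, since both pairs wrap to the same axes.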
lib/Dialect/TTNN/IR/TTNNOps.cpp (16 changes: 8 additions & 8 deletions)
@@ -55,30 +55,30 @@ ::mlir::LogicalResult mlir::tt::ttnn::TransposeOp::verify() {
   ::mlir::RankedTensorType outputType = getOutput().getType();
   auto inputShape = inputType.getShape();
   auto outputShape = outputType.getShape();
+  int32_t dim0 = getDimension0();
   int32_t dim1 = getDimension1();
-  int32_t dim2 = getDimension2();
   if (inputType.getRank() < 2) {
     return emitOpError("Input must be at least a 2D tensor");
   }
   if (inputType.getRank() != outputType.getRank()) {
     return emitOpError("Input must have the same rank as output");
   }
-  if (dim1 >= inputType.getRank() || dim1 < -inputType.getRank()) {
+  if (dim0 >= inputType.getRank() || dim0 < -inputType.getRank()) {
     return emitOpError(
         "Dimension 1 attribute must be within the bounds of the input tensor");
   }
-  if (dim2 >= inputType.getRank() || dim2 < -inputType.getRank()) {
+  if (dim1 >= inputType.getRank() || dim1 < -inputType.getRank()) {
     return emitOpError(
         "Dimension 2 attribute must be within the bounds of the input tensor");
   }
+  if (dim0 < 0) {
+    dim0 += inputType.getRank();
+  }
   if (dim1 < 0) {
     dim1 += inputType.getRank();
   }
-  if (dim2 < 0) {
-    dim2 += inputType.getRank();
-  }
-  if (outputShape[dim1] != inputShape[dim2] ||
-      outputShape[dim2] != inputShape[dim1]) {
+  if (outputShape[dim0] != inputShape[dim1] ||
+      outputShape[dim1] != inputShape[dim0]) {
     return emitOpError("Input-output transpose dimension mismatch.");
   }
   return success();
lib/Target/TTNN/TTNNToFlatbuffer.cpp (6 changes: 3 additions & 3 deletions)
@@ -166,11 +166,11 @@ createTransposeOp(FlatbufferObjectCache &cache, TransposeOp op) {
       cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput()));
   auto out = cache.at<::tt::target::TensorRef>(
       getOperandThroughDPSOps(op.getResult()));
+  int32_t dimension0 = op.getDimension0();
   int32_t dimension1 = op.getDimension1();
-  int32_t dimension2 = op.getDimension2();

-  return ::tt::target::ttnn::CreateTransposeOp(*cache.fbb, in, out, dimension1,
-                                               dimension2);
+  return ::tt::target::ttnn::CreateTransposeOp(*cache.fbb, in, out, dimension0,
+                                               dimension1);
 }

template <typename SoftmaxOp>
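For reference, CreateTransposeOp is generated by flatc from the TransposeOp table in program.fbs above; its signature should look roughly like this (an assumption based on standard flatc C++ output, where builder arguments follow field declaration order):

    // Assumed flatc-generated builder; the old trailing dimension2 parameter
    // disappears because the schema field was removed.
    ::flatbuffers::Offset<::tt::target::ttnn::TransposeOp>
    CreateTransposeOp(::flatbuffers::FlatBufferBuilder &_fbb,
                      ::flatbuffers::Offset<::tt::target::TensorRef> in = 0,
                      ::flatbuffers::Offset<::tt::target::TensorRef> out = 0,
                      int32_t dimension0 = 0, int32_t dimension1 = 0);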
runtime/lib/ttnn/program.cpp (10 changes: 5 additions & 5 deletions)
@@ -285,18 +285,18 @@ run(::tt::target::ttnn::TransposeOp const *op, ::ttnn::device::Device &device,
     std::unordered_map<std::uint32_t, ::ttnn::Tensor *> &liveTensors,
     std::list<::ttnn::Tensor> &tensorPool) {
   ::ttnn::Tensor &in = *liveTensors.at(op->in()->global_id());
+  int32_t dimension0 = op->dimension0();
   int32_t dimension1 = op->dimension1();
-  int32_t dimension2 = op->dimension2();
   auto input_rank = in.get_shape().rank();
   std::vector<int> dimensionOrder(input_rank);
   std::iota(dimensionOrder.begin(), dimensionOrder.end(), 0);
+  if (dimension0 < 0) {
+    dimension0 += input_rank;
+  }
   if (dimension1 < 0) {
     dimension1 += input_rank;
   }
-  if (dimension2 < 0) {
-    dimension2 += input_rank;
-  }
-  std::swap(dimensionOrder[dimension1], dimensionOrder[dimension2]);
+  std::swap(dimensionOrder[dimension0], dimensionOrder[dimension1]);
   tensorPool.push_back(
       ::ttnn::operations::data_movement::permute(in, dimensionOrder));
   liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
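The runtime implements transpose as a full permute, so the core step is building an identity permutation and swapping the two requested axes. A self-contained sketch of just that step (a free-standing helper written for illustration, not the actual runtime entry point):

    #include <numeric>
    #include <utility>
    #include <vector>

    // Builds the dimension order that program.cpp hands to permute():
    // the identity permutation with the two transposed axes exchanged.
    std::vector<int> transposePermutation(int rank, int dim0, int dim1) {
      if (dim0 < 0)
        dim0 += rank; // negative dims wrap, as in the runtime code above
      if (dim1 < 0)
        dim1 += rank;
      std::vector<int> order(rank);
      std::iota(order.begin(), order.end(), 0); // 0, 1, ..., rank-1
      std::swap(order[dim0], order[dim1]);
      return order;
    }

For example, transposePermutation(4, -1, -2) returns {0, 1, 3, 2}, which permute applies to swap the last two axes.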
test/ttmlir/Dialect/TTNN/transpose/simple_transpose.mlir (2 changes: 1 addition & 1 deletion)
@@ -4,7 +4,7 @@ module attributes {tt.system_desc = #tt.system_desc<[{arch = <wormhole_b0>, grid
 func.func @forward(%arg0: tensor<64x128xbf16>) -> tensor<128x64xbf16> {
   %0 = tensor.empty() : tensor<128x64xbf16>
   // CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]]
-  %1 = "ttir.transpose"(%arg0, %0) <{dimension1 = 0 : si32, dimension2 = 1 : si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>) -> tensor<128x64xbf16>
+  %1 = "ttir.transpose"(%arg0, %0) <{dimension0 = 0 : si32, dimension1 = 1 : si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>) -> tensor<128x64xbf16>
   return %1 : tensor<128x64xbf16>
 }
 }
@@ -4,7 +4,7 @@ module attributes {tt.system_desc = #tt.system_desc<[{arch = <wormhole_b0>, grid
 func.func @forward(%arg0: tensor<8x16xbf16>) -> tensor<16x8xbf16> {
   %0 = tensor.empty() : tensor<16x8xbf16>
   // CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]]
-  %1 = "ttir.transpose"(%arg0, %0) <{dimension1 = 1 : si32, dimension2 = 0 : si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<8x16xbf16>, tensor<16x8xbf16>) -> tensor<16x8xbf16>
+  %1 = "ttir.transpose"(%arg0, %0) <{dimension0 = 1 : si32, dimension1 = 0 : si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<8x16xbf16>, tensor<16x8xbf16>) -> tensor<16x8xbf16>
   return %1 : tensor<16x8xbf16>
 }
 }
@@ -4,7 +4,7 @@ module attributes {tt.system_desc = #tt.system_desc<[{arch = <wormhole_b0>, grid
 func.func @forward(%arg0: tensor<8x8xbf16>) -> tensor<8x8xbf16> {
   %0 = tensor.empty() : tensor<8x8xbf16>
   // CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]]
-  %1 = "ttir.transpose"(%arg0, %0) <{dimension1 = 0 : si32, dimension2 = 1 : si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<8x8xbf16>, tensor<8x8xbf16>) -> tensor<8x8xbf16>
+  %1 = "ttir.transpose"(%arg0, %0) <{dimension0 = 0 : si32, dimension1 = 1 : si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<8x8xbf16>, tensor<8x8xbf16>) -> tensor<8x8xbf16>
   return %1 : tensor<8x8xbf16>
 }
 }
@@ -4,7 +4,7 @@ module attributes {tt.system_desc = #tt.system_desc<[{arch = <wormhole_b0>, grid
 func.func @forward(%arg0: tensor<8x8xbf16>) -> tensor<8x8xbf16> {
   %0 = tensor.empty() : tensor<8x8xbf16>
   // CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]]
-  %1 = "ttir.transpose"(%arg0, %0) <{dimension1 = -1 : si32, dimension2 = -2 : si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<8x8xbf16>, tensor<8x8xbf16>) -> tensor<8x8xbf16>
+  %1 = "ttir.transpose"(%arg0, %0) <{dimension0 = -1 : si32, dimension1 = -2 : si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<8x8xbf16>, tensor<8x8xbf16>) -> tensor<8x8xbf16>
   return %1 : tensor<8x8xbf16>
 }
 }
