Skip to content

Commit

Permalink
Enforce to_layout conversion for ops constrained as tile only (#383)
Browse the repository at this point in the history
  • Loading branch information
nsmithtt authored Aug 15, 2024
1 parent 502d8ee commit fb9cba1
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 33 deletions.
28 changes: 28 additions & 0 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,34 @@ parseDimensionList(::mlir::AsmParser &odsParser,
::llvm::SmallVector<int64_t> &dimensions) {
return odsParser.parseDimensionList(dimensions, false, false);
}

// Maps an MLIR element type to the corresponding TT DataType.
//
// Note: signedness is ignored — every supported integer width maps to the
// unsigned DataType variant (UInt32/UInt16/UInt8).
//
// Unsupported types (e.g. f64, i64, index, complex) trip an assert in debug
// builds; in release builds they fall back to Float32, preserving the
// previous silent-fallback behavior for callers that relied on it.
inline DataType elementTypeToDataType(Type elementType) {
  if (auto floatType = mlir::dyn_cast<FloatType>(elementType)) {
    if (floatType.isF32()) {
      return DataType::Float32;
    }
    if (floatType.isF16()) {
      return DataType::Float16;
    }
    if (floatType.isBF16()) {
      return DataType::BFloat16;
    }
    assert(false && "unsupported float type");
    return DataType::Float32;
  }
  if (auto intType = mlir::dyn_cast<IntegerType>(elementType)) {
    switch (intType.getWidth()) {
    case 32:
      return DataType::UInt32;
    case 16:
      return DataType::UInt16;
    case 8:
      return DataType::UInt8;
    default:
      assert(false && "unsupported integer type");
      return DataType::Float32;
    }
  }
  // Previously this case fell through silently; make it loud in debug builds.
  assert(false && "unsupported element type");
  return DataType::Float32;
}
} // namespace mlir::tt

#define GET_ATTRDEF_CLASSES
Expand Down
5 changes: 5 additions & 0 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,10 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> {
GridAttr grid = {},
ArrayRef<std::pair<std::int64_t, std::int64_t>> collapseIntervals = {{0, -1}},
OOBVal oobVal = OOBVal::Undef);
static LayoutAttr get(::mlir::MLIRContext *context,
RankedTensorType ty,
MemorySpace memorySpace,
Type elementType);
LayoutAttr withGrid(::mlir::MLIRContext *context, ArrayRef<int64_t> tensorShape, GridAttr grid, ArrayRef<std::pair<std::int64_t, std::int64_t>> collapseIntervals = {{0, -1}});
LayoutAttr withGrid(::mlir::MLIRContext *context,
RankedTensorType ty,
Expand Down Expand Up @@ -275,6 +279,7 @@ def TT_Tile : TT_Type<"Tile", "tile", [MemRefElementTypeInterface]> {
let assemblyFormat = "`<` custom<DimensionList>($shape) `,` $dataType `>`";

let extraClassDeclaration = [{
static TileType get(::mlir::MLIRContext *context, Type elementType, ArrayRef<int64_t> shape = {32, 32});
SmallVector<int64_t> getScalarShape(SmallVector<int64_t> tiledShape) const;
SmallVector<int64_t> getTiledShape(SmallVector<int64_t> scalarShape) const;
uint64_t getSizeBytes() const;
Expand Down
28 changes: 0 additions & 28 deletions include/ttmlir/Target/Utils/MLIRToFlatbuffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,34 +260,6 @@ toFlatbuffer(FlatbufferObjectCache &cache, GridAttr tensorGrid,
return coreRangeSet;
}

// Translates an MLIR element type into the matching TT DataType, defaulting
// to Float32 when no mapping applies. Integer widths map to the unsigned
// DataType variants regardless of signedness.
inline DataType elementTypeToDataType(Type elementType) {
  DataType result = DataType::Float32;
  if (auto floatTy = mlir::dyn_cast<FloatType>(elementType)) {
    if (floatTy.isF16()) {
      result = DataType::Float16;
    } else if (floatTy.isBF16()) {
      result = DataType::BFloat16;
    } else {
      // Only f32 remains among the supported float widths.
      assert(floatTy.isF32() && "unsupported float type");
      result = DataType::Float32;
    }
  } else if (auto intTy = mlir::dyn_cast<IntegerType>(elementType)) {
    switch (intTy.getWidth()) {
    case 32:
      result = DataType::UInt32;
      break;
    case 16:
      result = DataType::UInt16;
      break;
    case 8:
      result = DataType::UInt8;
      break;
    default:
      assert(false && "unsupported integer type");
      break;
    }
  }
  return result;
}

template <typename AttrType, typename ValueType>
struct ArrayAttrToFlatbufferSerializer {
static flatbuffers::Offset<flatbuffers::Vector<ValueType>>
Expand Down
12 changes: 12 additions & 0 deletions lib/Dialect/TT/IR/TTOpsTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,13 @@ LayoutAttr LayoutAttr::get(
collapseIntervals, oobVal);
}

// Convenience builder: constructs a LayoutAttr for `ty` using its shape,
// the default collapse intervals ({0, -1}), an empty (default) grid, and an
// undefined out-of-bounds value.
LayoutAttr LayoutAttr::get(::mlir::MLIRContext *context, RankedTensorType ty,
                           MemorySpace memorySpace, Type elementType) {
  assert(ty);
  auto shape = ty.getShape();
  return get(context, shape, elementType, memorySpace, /*grid=*/{},
             /*collapseIntervals=*/{{0, -1}}, OOBVal::Undef);
}

// From the logical shape of the tensor and the affine map of the layout,
// compute the physical shape of the tensor, i.e the shape of the tensor
// after the dimensions have been collapsed onto a grid.
Expand Down Expand Up @@ -472,6 +479,11 @@ TileType::verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
return ::mlir::success();
}

// Convenience builder that takes a scalar element type and converts it to
// the TT DataType expected by the generated builder. The shape parameter
// defaults to 32x32 (see the declaration in TTOpsTypes.td).
TileType TileType::get(::mlir::MLIRContext *context, Type elementType,
                       ArrayRef<int64_t> shape) {
  const DataType dataType = elementTypeToDataType(elementType);
  return get(context, shape, dataType);
}

llvm::SmallVector<int64_t>
TileType::getScalarShape(SmallVector<int64_t> tiledShape) const {
assert(tiledShape.size() >= 2 && "expected at least 2D shape");
Expand Down
20 changes: 15 additions & 5 deletions lib/Dialect/TTIR/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -418,15 +418,22 @@ class TTIRLayoutTensorTypeRewriter : public RewritePattern {

static std::optional<Value> createToLayoutOp(PatternRewriter &rewriter,
Location loc, Value input,
MemorySpace desiredMemorySpace) {
MemorySpace desiredMemorySpace,
bool tiled) {
auto ty = mlir::cast<RankedTensorType>(input.getType());
auto currLayout = mlir::cast<LayoutAttr>(ty.getEncoding());
auto currMemorySpace = currLayout.getMemorySpace();
if (currMemorySpace == desiredMemorySpace) {
auto currElementType = currLayout.getElementType();
auto desiredElementType =
tiled ? rewriter.getType<TileType>(ty.getElementType())
: ty.getElementType();
if (currMemorySpace == desiredMemorySpace &&
currElementType == desiredElementType) {
return std::nullopt;
}

auto desiredLayout = rewriter.getAttr<LayoutAttr>(ty, desiredMemorySpace);
auto desiredLayout =
rewriter.getAttr<LayoutAttr>(ty, desiredMemorySpace, desiredElementType);
auto output = rewriter.create<tensor::EmptyOp>(
loc, ty.getShape(), ty.getElementType(), desiredLayout);

Expand All @@ -444,7 +451,9 @@ static std::optional<Value>
createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input,
OperandConstraint operandConstraint) {
auto desiredMemorySpace = uppermostMemorySpace(operandConstraint);
return createToLayoutOp(rewriter, loc, input, desiredMemorySpace);
bool tiled =
!bitEnumContainsAny(operandConstraint, OperandConstraint::Scalar);
return createToLayoutOp(rewriter, loc, input, desiredMemorySpace, tiled);
}

class TTIRLayoutDPSOperandsRewriter
Expand Down Expand Up @@ -516,8 +525,9 @@ class TTIRLayoutFuncReturnRewriter
for (auto &operand : op->getOpOperands()) {
// Leave the return values in initMemorySpace, optimizer might decide
// otherwise
bool tiled = false;
if (auto layout = createToLayoutOp(rewriter, op.getLoc(), operand.get(),
initMemorySpace);
initMemorySpace, tiled);
layout) {
rewriter.modifyOpInPlace(
op, [&]() { op.setOperand(operand.getOperandNumber(), *layout); });
Expand Down
1 change: 1 addition & 0 deletions test/ttmlir/Dialect/TTNN/simple_matmul.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
// CHECK: #[[TILED_LAYOUT:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #l1_>>
module attributes {} {
func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>) -> tensor<64x96xbf16> {
%0 = tensor.empty() : tensor<64x96xbf16>
Expand Down

0 comments on commit fb9cba1

Please sign in to comment.