forked from iree-org/iree
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Encoding][NFC] Moving Encoding attr/enum to Encoding[Types|Attrs].* (i…
…ree-org#18711) The revision keeps `EncodingBase.td` simple. It follows the IREE core dialect style, which moves the declarations to `EncodingTypes.h` and implementation to `EncodingAttrs.cpp`. Signed-off-by: hanhanW <[email protected]>
- Loading branch information
Showing
10 changed files
with
377 additions
and
320 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
160 changes: 160 additions & 0 deletions
160
compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
// Copyright 2024 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h" | ||
|
||
#include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h" | ||
#include "llvm/ADT/SmallVector.h" | ||
#include "llvm/ADT/TypeSwitch.h" | ||
#include "mlir/Dialect/Affine/Utils.h" | ||
#include "mlir/Dialect/Linalg/Utils/Utils.h" | ||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h" | ||
#include "mlir/IR/Attributes.h" | ||
#include "mlir/IR/Builders.h" | ||
#include "mlir/IR/TypeUtilities.h" | ||
#include "mlir/Interfaces/InferTypeOpInterface.h" | ||
#include "mlir/Support/LLVM.h" | ||
#include "mlir/Support/LogicalResult.h" | ||
|
||
namespace mlir::iree_compiler::IREE::Encoding { | ||
|
||
EncodingAttr EncodingAttr::get(MLIRContext *ctx, int64_t operandIndex, | ||
EncodingOpType opType, ArrayRef<Type> elemTypes, | ||
ArrayRef<AffineMap> maps, | ||
std::optional<AffineMap> bcastMap, | ||
ArrayRef<int64_t> roundDimsTo) { | ||
Builder b(ctx); | ||
auto opTypeAttr = EncodingOpTypeAttr::get(ctx, opType); | ||
auto roundDimsToAttr = roundDimsTo.empty() | ||
? DenseI64ArrayAttr() | ||
: b.getDenseI64ArrayAttr(roundDimsTo); | ||
auto bcastMapAttr = bcastMap.has_value() | ||
? AffineMapAttr::get(bcastMap.value()) | ||
: AffineMapAttr(); | ||
return get(ctx, b.getIndexAttr(operandIndex), opTypeAttr, | ||
b.getTypeArrayAttr(elemTypes), b.getAffineMapArrayAttr(maps), | ||
bcastMapAttr, roundDimsToAttr); | ||
} | ||
|
||
AffineMap EncodingAttr::getMapForOperandIndex() { | ||
auto index = getOperandIndex().getValue().getZExtValue(); | ||
switch (index) { | ||
case MATMUL_LHS: | ||
case MATMUL_RHS: | ||
case MATMUL_RESULT: { | ||
auto indexingMap = | ||
llvm::cast<AffineMapAttr>(getUserIndexingMaps()[index]).getAffineMap(); | ||
if (auto bcastMap = getBcastMap()) { | ||
indexingMap = bcastMap.getAffineMap().compose(indexingMap); | ||
} | ||
return indexingMap; | ||
} | ||
default: | ||
return AffineMap(); | ||
} | ||
} | ||
|
||
std::optional<unsigned> EncodingAttr::mapDimToOperandIndex(int64_t dimPos) { | ||
return getMapForOperandIndex().getResultPosition( | ||
getAffineDimExpr(dimPos, getContext())); | ||
} | ||
|
||
MatmulNarrowDim getMatmulNarrowDim(linalg::LinalgOp linalgOp, | ||
int narrowThreshold) { | ||
linalg::ContractionDimensions cDims = | ||
linalg::inferContractionDims(linalgOp).value(); | ||
auto map = linalgOp.getIndexingMapsArray().back(); | ||
auto outType = llvm::cast<ShapedType>(linalgOp.getDpsInits()[0].getType()); | ||
auto getOutputSizeAtDimPos = [=](unsigned dimPos) -> int64_t { | ||
return outType.getDimSize( | ||
map.getResultPosition(getAffineDimExpr(dimPos, linalgOp->getContext())) | ||
.value()); | ||
}; | ||
// M or N can be empty instead of having an explicit dim size of 1 for matvec | ||
// and vecmat, so set to 1 if empty. | ||
int64_t mSize = cDims.m.empty() ? 1 : getOutputSizeAtDimPos(cDims.m[0]); | ||
int64_t nSize = cDims.n.empty() ? 1 : getOutputSizeAtDimPos(cDims.n[0]); | ||
|
||
MatmulNarrowDim narrowM, narrowN; | ||
if (!ShapedType::isDynamic(mSize) && mSize < narrowThreshold) { | ||
narrowM = {/*dim=*/MatmulNarrowDim::Dim::M, /*size=*/mSize}; | ||
} | ||
if (!ShapedType::isDynamic(nSize) && nSize < narrowThreshold) { | ||
narrowN = {/*dim=*/MatmulNarrowDim::Dim::N, /*size=*/nSize}; | ||
} | ||
|
||
return (narrowM && (!narrowN || mSize <= nSize)) ? narrowM : narrowN; | ||
} | ||
|
||
ArrayRef<int64_t> EncodingAttr::getRoundDimsToArray() { | ||
auto roundDimsTo = getRoundDimsTo(); | ||
if (!roundDimsTo) { | ||
return {}; | ||
} | ||
return llvm::cast<DenseI64ArrayAttr>(roundDimsTo).asArrayRef(); | ||
} | ||
|
||
SmallVector<Type> EncodingAttr::getElementTypesArray() { | ||
return llvm::map_to_vector(getElementTypes().getValue(), [](Attribute a) { | ||
return llvm::cast<TypeAttr>(a).getValue(); | ||
}); | ||
} | ||
|
||
EncodingAttr EncodingAttr::clone(AffineMap bcastMap) { | ||
return get(bcastMap.getContext(), getOperandIndex(), getOpType(), | ||
getElementTypes(), getUserIndexingMaps(), | ||
AffineMapAttr::get(bcastMap), getRoundDimsTo()); | ||
} | ||
|
||
MatmulNarrowDim getMatmulNarrowDim(EncodingAttr encoding) { | ||
if (encoding.getOpType().getValue() != EncodingOpType::matmul) { | ||
return {}; | ||
} | ||
ArrayRef<int64_t> roundDimsTo = encoding.getRoundDimsToArray(); | ||
if (roundDimsTo.empty()) { | ||
return {}; | ||
} | ||
int m = roundDimsTo[0]; | ||
int n = roundDimsTo[1]; | ||
if (m < n) { | ||
return {MatmulNarrowDim::Dim::M, m}; | ||
} | ||
if (n < m) { | ||
return {MatmulNarrowDim::Dim::N, n}; | ||
} | ||
return {}; | ||
} | ||
|
||
EncodingAttr getEncodingAttr(RankedTensorType type) { | ||
return dyn_cast_or_null<EncodingAttr>(type.getEncoding()); | ||
} | ||
|
||
FailureOr<linalg::ContractionDimensions> | ||
getEncodingContractionDims(EncodingAttr encoding) { | ||
auto indexingMapsAttr = encoding.getUserIndexingMaps(); | ||
SmallVector<AffineMap> indexingMaps = llvm::map_to_vector( | ||
indexingMapsAttr.getValue(), [](Attribute m) -> AffineMap { | ||
return cast<AffineMapAttr>(m).getAffineMap(); | ||
}); | ||
return linalg::inferContractionDims(indexingMaps); | ||
} | ||
|
||
std::string stringifyOperandIndex(IntegerAttr valueAttr) { | ||
auto value = valueAttr.getValue().getZExtValue(); | ||
switch (value) { | ||
case MATMUL_LHS: | ||
return "LHS"; | ||
case MATMUL_RHS: | ||
return "RHS"; | ||
case MATMUL_RESULT: | ||
return "RESULT"; | ||
default: | ||
assert(false && "invalid index"); | ||
return ""; | ||
} | ||
} | ||
|
||
} // namespace mlir::iree_compiler::IREE::Encoding |
104 changes: 104 additions & 0 deletions
104
compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
// Copyright 2024 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#ifndef IREE_DIALECT_ENCODING_ATTRS | ||
#define IREE_DIALECT_ENCODING_ATTRS | ||
|
||
include "iree/compiler/Dialect/Encoding/IR/EncodingBase.td" | ||
include "mlir/IR/AttrTypeBase.td" | ||
include "mlir/IR/EnumAttr.td" | ||
|
||
//===---------------------------------------------------------------------===// | ||
// Data layout encoding attributes | ||
//===---------------------------------------------------------------------===// | ||
|
||
class IREEEncoding_Attr<string name, list<Trait> traits = []> | ||
: AttrDef<IREEEncoding_Dialect, name, traits>; | ||
|
||
class IREEEncoding_I32EnumAttr<string name, string summary, list<I32EnumAttrCase> cases> | ||
: I32EnumAttr<name, summary, cases> { | ||
let cppNamespace = "::mlir::iree_compiler::IREE::Encoding"; | ||
let genSpecializedAttr = 0; | ||
} | ||
|
||
class IREEEncoding_EnumAttr<EnumAttrInfo enumInfo, string name = ""> | ||
: EnumAttr<IREEEncoding_Dialect, enumInfo, name>; | ||
|
||
// Enums for tagging operand operation in an EncodingAttr | ||
def MATMUL : I32EnumAttrCase<"matmul", 0>; | ||
def CONV : I32EnumAttrCase<"conv", 1>; | ||
|
||
def EncodingOpType : IREEEncoding_I32EnumAttr<"EncodingOpType", | ||
"Tracks the type of operation of the operand.", [ | ||
MATMUL, | ||
CONV, | ||
]>; | ||
|
||
def EncodingOpTypeAttr: | ||
IREEEncoding_EnumAttr<EncodingOpType, "optype">; | ||
|
||
def EncodingAttr : | ||
IREEEncoding_Attr<"Encoding"> { | ||
let mnemonic = "encoding"; | ||
let summary = [{information to decide how to data-tile a tensor}]; | ||
let description = [{ | ||
This attribute describes the change in the layout for | ||
a given tensor to execute subsequent operations on | ||
the tiled layout. The encoding serves as a way to | ||
represent the change in the way the data is laid out in | ||
memory without changing the logical rank/extent of | ||
the tensor itself. When required, the encoding | ||
can be used to explicitly manifest the layout change | ||
through operations like pack/unpack. | ||
}]; | ||
|
||
let assemblyFormat = "`<` struct(params) `>`"; | ||
|
||
let parameters = (ins | ||
AttrParameter<"IntegerAttr", "this tensor operand's index in the parameter list">:$operand_index, | ||
AttrParameter<"EncodingOpTypeAttr", "operand type">:$op_type, | ||
AttrParameter<"ArrayAttr", "element types of the user's operands">:$element_types, | ||
OptionalParameter<"ArrayAttr", "Indexing maps of the operation using this tensor">:$user_indexing_maps, | ||
OptionalParameter<"AffineMapAttr", "Indexing map that represents the broadcasting dims in the producer">:$bcast_map, | ||
// TODO(hanchung): The round_dims_to parameter can be revisited. We explicitly map them to M,N,K dimension for now. | ||
OptionalParameter<"DenseArrayAttr", "Values for padding M,N,K dimensions">:$round_dims_to | ||
); | ||
|
||
let builders = [ | ||
AttrBuilder<(ins "int64_t":$operandIndex, | ||
"EncodingOpType":$opType, | ||
"ArrayRef<Type>":$elemTypes, | ||
CArg<"ArrayRef<AffineMap>", "{}">:$maps, | ||
CArg<"std::optional<AffineMap>", "{}">:$bcastMap, | ||
CArg<"ArrayRef<int64_t>", "{}">:$roundDimsTo)> | ||
]; | ||
|
||
let extraClassDeclaration = [{ | ||
/// Returns the bcast_map composed with the user_indexing_map for the | ||
/// operand_index. The dimensions of the returned map are those of the | ||
/// data-tiled op's iteration space, and the results of the map are in | ||
/// the domain of the encoded tensor type. | ||
AffineMap getMapForOperandIndex(); | ||
|
||
/// Given the dim position of the encoding `user_indexing_maps`, returns the | ||
/// matching index of the given encoding's tensor, using getMapForOperandIndex | ||
/// bcast_map and user_indexing_map. | ||
std::optional<unsigned> mapDimToOperandIndex(int64_t dimPos); | ||
|
||
/// Returns an integer array with values in `round_dims_to`. | ||
ArrayRef<int64_t> getRoundDimsToArray(); | ||
|
||
/// Returns a vector with values in `element_types`. | ||
SmallVector<Type> getElementTypesArray(); | ||
|
||
/// Clones an encoding with a new bcast_map | ||
EncodingAttr clone(AffineMap bcastMap); | ||
}]; | ||
|
||
let genVerifyDecl = 0; | ||
} | ||
|
||
#endif // IREE_DIALECT_ENCODING_ATTRS |
Oops, something went wrong.