Skip to content

Commit

Permalink
Added 2 Ops: Floor divide scalar and Floor divide scalar mode (llvm#3156
Browse files Browse the repository at this point in the history
)

- Added linalg lowering for `AtenFloorDivideScalarOp`
  - Needed `AtenDivScalarModeOp` for the decomp.
- Added linalg lowering for `AtenDivScalarModeOp`
- Moved linalg payload logic to `createDivModePayload()` since the logic
was nearly identical for both `AtenDivScalarModeOp` and
`AtenDivTensorModeOp`. Just a template function
- Added `AtenDivScalarModeOp` lowering for stablehlo
 

Pytorch's
[`torch.floor_divide()`](https://pytorch.org/docs/stable/generated/torch.floor_divide.html)
in a previous version (for a reason unknown to me) performed a
truncation instead of "floor". The already implemented op
`AtenFloorDivideTensorOp` was done before this change. However, this
wasn't caught because our testcases only tested positive floor division.
I changed this to floor as well as adding a few test cases.
  • Loading branch information
IanWood1 authored Apr 15, 2024
1 parent 83cba8c commit 5708ee7
Show file tree
Hide file tree
Showing 11 changed files with 564 additions and 158 deletions.
50 changes: 50 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3397,6 +3397,56 @@ def Torch_AtenDiv_TensorModeOp : Torch_Op<"aten.div_.Tensor_mode", [
}];
}

// Value-semantics op for `aten::div.Scalar_mode`: divides tensor `self` by
// scalar `other`, with an optional `rounding_mode` string (per the summary's
// schema it is `str?`, i.e. may be None).
// NOTE(review): this lives in GeneratedTorchOps.td — presumably emitted by the
// op-registration generator, so regeneration may overwrite manual edits here.
def Torch_AtenDivScalarModeOp : Torch_Op<"aten.div.Scalar_mode", [
  AllowsTypeRefinement,
  HasValueSemantics,
  ReadOnly
  ]> {
  let summary = "Generated op for `aten::div.Scalar_mode : (Tensor, Scalar, str?) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchScalarType:$other,
    AnyTorchOptionalStringType:$rounding_mode
  );
  let results = (outs
    AnyTorchOptionalTensorType:$result
  );
  // Custom assembly uses the shared default torch-op parser/printer with
  // 3 operands and 1 result.
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenDivScalarModeOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 3, 1);
    }
    void AtenDivScalarModeOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
  // Canonicalization patterns are registered in TorchOps.cpp
  // (getCanonicalizationPatterns).
  let hasCanonicalizer = 1;
}

// In-place (trailing-underscore) variant of `aten::div.Scalar_mode`: mutates
// the non-value tensor `self` instead of producing a fresh value tensor.
// NOTE(review): generated file — edits are normally made via the
// op-registration generator rather than by hand.
def Torch_AtenDiv_ScalarModeOp : Torch_Op<"aten.div_.Scalar_mode", [
  IsTrailingUnderscoreInplaceVariant,
  AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::div_.Scalar_mode : (Tensor, Scalar, str?) -> (Tensor)`";
  let arguments = (ins
    Torch_NonValueTensorType:$self,
    AnyTorchScalarType:$other,
    AnyTorchOptionalStringType:$rounding_mode
  );
  let results = (outs
    AnyTorchOptionalNonValueTensorType:$result
  );
  // Same shared default parser/printer as the value-semantics variant:
  // 3 operands, 1 result.
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenDiv_ScalarModeOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 3, 1);
    }
    void AtenDiv_ScalarModeOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
}

def Torch_AtenMulTensorOp : Torch_Op<"aten.mul.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
188 changes: 105 additions & 83 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/APSInt.h"
#include <numeric>
#include <type_traits>

using namespace mlir;
using namespace mlir::torch;
Expand Down Expand Up @@ -213,6 +214,78 @@ createTriangularMatrix(OpBuilder &b, Location loc, ValueRange payloadArgs,
return success();
}

// Emits the linalg payload for aten division with an optional rounding mode.
// Shared between the tensor/tensor and tensor/scalar variants: for the scalar
// variant the divisor comes from the op's (already converted) operand list
// rather than from the generic payload arguments.
//
// Returns the computed value, or nullptr after emitting an error when the
// rounding mode is not a constant string or is not one of
// None / "trunc" / "floor".
template <typename OpT>
Value createDivModePayload(OpBuilder &b, Location loc,
                           const TypeConverter *converter,
                           ValueRange payloadArgs, OpT op,
                           ArrayRef<Value> operands) {
  static_assert(std::is_same_v<OpT, AtenDivTensorModeOp> ||
                    std::is_same_v<OpT, AtenDivScalarModeOp>,
                "template type must be a tensor/scalar div mode");
  // The converted result element type determines which arith div op to emit.
  Type dtype = cast<RankedTensorType>(converter->convertType(op.getType()))
                   .getElementType();
  Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
  Value rhs = convertScalarToDtype(
      b, loc,
      std::is_same_v<OpT, AtenDivScalarModeOp> ? operands[1] : payloadArgs[1],
      dtype);

  Value quotient;
  if (isa<mlir::FloatType>(dtype)) {
    quotient = b.create<arith::DivFOp>(loc, lhs, rhs);
  } else if (dtype.isUnsignedInteger()) {
    quotient = b.create<arith::DivUIOp>(loc, lhs, rhs);
  } else {
    assert(dtype.isInteger() &&
           "dtype should be an integer (signless or signed)");
    quotient = b.create<arith::DivSIOp>(loc, lhs, rhs);
  }

  // rounding_mode == None: plain true division, nothing more to do.
  if (isa<Torch::NoneType>(op.getRoundingMode().getType()))
    return quotient;

  std::string roundingMode;
  if (!matchPattern(op.getRoundingMode(), m_TorchConstantStr(roundingMode))) {
    op.emitError("only support constant str rounding mode");
    return nullptr;
  }
  // Diagnose invalid modes explicitly: a bare assert disappears in release
  // builds and would silently return the unrounded quotient.
  if (roundingMode != "trunc" && roundingMode != "floor") {
    op.emitError("invalid rounding mode: only None, 'trunc' and 'floor' "
                 "are supported");
    return nullptr;
  }

  if (roundingMode == "trunc") {
    // "trunc" - rounds the results of the division towards zero. Equivalent
    // to C-style integer division.
    if (!isa<mlir::FloatType>(dtype)) {
      // Integer division already truncates towards zero.
      return quotient;
    }

    // Float: select ceil for negative quotients, floor otherwise.
    Value ceil = b.create<math::CeilOp>(loc, quotient);
    Value floor = b.create<math::FloorOp>(loc, quotient);
    Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
    Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
                                         quotient, cstZero);
    return b.create<arith::SelectOp>(loc, pred, ceil, floor);
  }

  // "floor" - rounds the results of the division down. Equivalent to
  // floor division in Python (the // operator).
  if (isa<mlir::FloatType>(dtype))
    return b.create<math::FloorOp>(loc, quotient);
  if (!dtype.isUnsignedInteger()) {
    // Signed integers: compute in f64 so negative quotients round down
    // (signed integer division truncates towards zero instead).
    Type defaultIntToFloatType = b.getF64Type();
    lhs = convertScalarToDtype(b, loc, lhs, defaultIntToFloatType);
    rhs = convertScalarToDtype(b, loc, rhs, defaultIntToFloatType);
    quotient = b.create<arith::DivFOp>(loc, lhs, rhs);
    Value floor = b.create<math::FloorOp>(loc, quotient);
    Value convert = convertScalarToDtype(b, loc, floor, dtype);
    return convert;
  }
  // Unsigned integers: the quotient is non-negative, so truncation and
  // floor coincide and DivUIOp is already correct.
  return quotient;
}

static Value createLinalgPayloadCalculationForElementwiseOp(
OpBuilder &b, Location loc, const TypeConverter *converter,
ValueRange payloadArgs, Operation *op, ArrayRef<Value> operands) {
Expand Down Expand Up @@ -769,66 +842,14 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
div.emitError("unimplemented: non-floating point and non-integer dtype");
return nullptr;
}
if (auto divScalarMode = dyn_cast<AtenDivScalarModeOp>(op)) {
return createDivModePayload(b, loc, converter, payloadArgs, divScalarMode,
operands);
}
if (auto divTensorMode = dyn_cast<AtenDivTensorModeOp>(op)) {
AtenDivTensorModeOp::Adaptor adaptor(operands);
Type dtype = converter->convertType(divTensorMode.getType())
.cast<RankedTensorType>()
.getElementType();
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
Value div;
if (isa<mlir::FloatType>(dtype))
div = b.create<arith::DivFOp>(loc, lhs, rhs);
else {
if (dtype.isUnsignedInteger())
div = b.create<arith::DivUIOp>(loc, lhs, rhs);
else
div = b.create<arith::DivSIOp>(loc, lhs, rhs);
}

if (divTensorMode.getRoundingMode().getType().isa<Torch::NoneType>())
return div;

std::string roundingMode;
if (!matchPattern(divTensorMode.getRoundingMode(),
m_TorchConstantStr(roundingMode))) {
divTensorMode.emitError("only support constant str rounding mode");
return nullptr;
}
if (roundingMode == "trunc") {
// "trunc" - rounds the results of the division towards zero. Equivalent
// to C-style integer division.
if (isa<mlir::FloatType>(dtype)) {
Value ceil = b.create<math::CeilOp>(loc, div);
Value floor = b.create<math::FloorOp>(loc, div);
Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
div, cstZero);
return b.create<arith::SelectOp>(loc, pred, ceil, floor);
} else
return div;
}
if (roundingMode == "floor") {
// "floor" - rounds the results of the division down. Equivalent to
// floor division in Python (the // operator)
if (isa<mlir::FloatType>(dtype))
return b.create<math::FloorOp>(loc, div);
else if (!dtype.isUnsignedInteger()) {
Type defaultIntToFloatType = b.getF64Type();
lhs = convertScalarToDtype(b, loc, lhs, defaultIntToFloatType);
rhs = convertScalarToDtype(b, loc, rhs, defaultIntToFloatType);
div = b.create<arith::DivFOp>(loc, lhs, rhs);
Value floor = b.create<math::FloorOp>(loc, div);
Value convert = convertScalarToDtype(b, loc, floor, dtype);
return convert;
} else {
return div;
}
}
divTensorMode.emitError("invalid rounding mode");
return nullptr;
return createDivModePayload(b, loc, converter, payloadArgs, divTensorMode,
operands);
}

if (auto pow = dyn_cast<AtenPowScalarOp>(op)) {
Type dtype = pow.getType().cast<ValueTensorType>().getDtype();
if (!isa<mlir::FloatType>(dtype)) {
Expand Down Expand Up @@ -1579,12 +1600,13 @@ class ConvertElementwiseOp : public ConversionPattern {
if (!isa<AtenTanOp, AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenReluOp,
AtenPreluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp,
AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp,
AtenSubTensorOp, AtenAtan2Op, AtenLerpTensorOp, AtenSigmoidOp,
AtenExpOp, AtenExpm1Op, AtenMinimumOp, AtenMaximumOp,
AtenToDtypeOp, AtenClampOp, AtenClampTensorOp, AtenRsubScalarOp,
AtenMulScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp,
AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp,
AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp,
AtenDivScalarModeOp, AtenSubTensorOp, AtenAtan2Op,
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenExpm1Op,
AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
AtenClampTensorOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp,
AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenPowScalarOp,
AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp,
AtenRemainderScalarOp, AtenRemainderTensorOp, AtenFmodTensorOp,
AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
Expand Down Expand Up @@ -2617,25 +2639,25 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenTanOp, AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenAtanhOp, AtenAcoshOp,
AtenAsinOp, AtenAsinhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp,
AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp,
AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp,
AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenClampTensorOp,
AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp,
AtenCeilOp, AtenPreluOp, AtenPowScalarOp, AtenPowTensorScalarOp,
AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp,
AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp,
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp,
AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp,
AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp,
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp,
AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenAcosOp,
AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, AtenTrilOp,
AtenRemainderScalarOp, AtenFmodTensorOp, AtenRemainderTensorOp,
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
AtenQuantizePerTensorOp>();
AtenDivScalarModeOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp,
AtenMinimumOp, AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
AtenClampTensorOp, AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp,
AtenFloorOp, AtenCeilOp, AtenPreluOp, AtenPowScalarOp,
AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op,
AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp,
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp,
AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
AtenTrilOp, AtenRemainderScalarOp, AtenFmodTensorOp,
AtenRemainderTensorOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
AtenFillTensorOp, AtenRealOp, AtenImagOp, AtenDequantizeSelfOp,
AtenDequantizeTensorOp, AtenQuantizePerTensorOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenNllLossForwardOp>();
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
Expand Down
24 changes: 15 additions & 9 deletions lib/Conversion/TorchToStablehlo/Basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include <iostream>
#include <numeric>
#include <type_traits>

using namespace mlir;
using namespace mlir::torch;
Expand Down Expand Up @@ -409,9 +409,9 @@ class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
if (!lhsType)
return op.emitError("only Tensor types supported in StableHLO");

auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
auto outType = cast<TensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));

Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat()) {
Expand All @@ -432,18 +432,23 @@ class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
Value result =
rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);

if (!isa<AtenDivTensorModeOp>(op)) {
if (!std::is_same<AtenDivTensorModeOp, AtenOpT>() &&
!std::is_same<AtenDivScalarModeOp, AtenOpT>()) {
rewriter.replaceOp(op, result);
return success();
}

AtenDivTensorModeOp divTensorModeOp =
llvm::dyn_cast<AtenDivTensorModeOp>(op.getOperation());
auto tensorOp = dyn_cast<AtenDivTensorModeOp>(op.getOperation());
auto opRoundingMode =
tensorOp
? tensorOp.getRoundingMode()
: cast<AtenDivScalarModeOp>(op.getOperation()).getRoundingMode();

std::string roundingMode;
if (!matchPattern(divTensorModeOp.getRoundingMode(),
m_TorchConstantStr(roundingMode)))
if (!matchPattern(opRoundingMode, m_TorchConstantStr(roundingMode))) {
return rewriter.notifyMatchFailure(
op, "only support constant str rounding mode");
}

// if trunc and int, do nothing
if (roundingMode == "trunc" && isa<mlir::FloatType>(outElemTy)) {
Expand Down Expand Up @@ -1845,6 +1850,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorOp, chlo::BroadcastDivOp);
INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorModeOp, chlo::BroadcastDivOp);
INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarOp, chlo::BroadcastDivOp);
INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarModeOp, chlo::BroadcastDivOp);
INSERT_BINARY_MULDIV_PATTERN(AtenRemainderScalarOp, chlo::BroadcastRemOp);
#undef INSERT_BINARY_MULDIV_PATTERN

Expand Down
14 changes: 12 additions & 2 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1095,9 +1095,9 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
rhs = rewriter.create<AtenMulIntOp>(loc, rhs, alpha);
}

if (isa<AtenDivTensorModeOp>(op)) {
// None rounding mode
if (isa<AtenDivTensorModeOp, AtenDivScalarModeOp>(op)) {
if (op->getOperand(2).getType().isa<Torch::NoneType>()) {
// None rounding mode
Value quotient = rewriter.create<AtenDivOp>(loc, lhs, rhs);
rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(op, outType,
quotient);
Expand Down Expand Up @@ -1858,6 +1858,16 @@ void AtenDivTensorModeOp::getCanonicalizationPatterns(
});
}

//===----------------------------------------------------------------------===//
// AtenDivScalarModeOp
//===----------------------------------------------------------------------===//
void AtenDivScalarModeOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Delegate to the shared 0-d binary tensor rewrite used by the other
  // div-mode canonicalizers.
  auto rewriteFn = +[](AtenDivScalarModeOp op, PatternRewriter &rewriter) {
    return rewrite0DBinaryTensorOp(op, rewriter);
  };
  patterns.add(rewriteFn);
}

//===----------------------------------------------------------------------===//
// AtenNumelOp
//===----------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit 5708ee7

Please sign in to comment.