From edd542d32d0ecbebb5f3631dedd92a221fa3225a Mon Sep 17 00:00:00 2001 From: Toshiya Kobayashi Date: Tue, 9 Jan 2024 16:13:53 +0100 Subject: [PATCH] [KIE-775] drools executable-model fails with a bind variable to a calculation result of int and BigDecimal (#5636) (cherry picked from commit fe2dbacafed4f7c4809e91f829442cc6da348df5) --- .../ArithmeticCoercedExpression.java | 4 + .../generator/drlxparse/ConstraintParser.java | 54 +++++++++-- .../expressiontyper/ExpressionTyper.java | 50 +++++++++- .../bigdecimaltest/BigDecimalTest.java | 92 +++++++++++++++++++ 4 files changed, 189 insertions(+), 11 deletions(-) diff --git a/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/drlxparse/ArithmeticCoercedExpression.java b/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/drlxparse/ArithmeticCoercedExpression.java index 89d7d77b9c7..dfd241615fe 100644 --- a/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/drlxparse/ArithmeticCoercedExpression.java +++ b/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/drlxparse/ArithmeticCoercedExpression.java @@ -51,6 +51,10 @@ public ArithmeticCoercedExpression(TypedExpression left, TypedExpression right, this.operator = operator; } + /* + * This coercion only deals with String vs Numeric types. + * BigDecimal arithmetic operation is handled by ExpressionTyper.convertArithmeticBinaryToMethodCall() + */ public ArithmeticCoercedExpressionResult coerce() { if (!requiresCoercion()) { diff --git a/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/drlxparse/ConstraintParser.java b/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/drlxparse/ConstraintParser.java index 0dfe9a472ff..387c9b051eb 100644 --- a/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/drlxparse/ConstraintParser.java +++ b/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/drlxparse/ConstraintParser.java @@ -94,6 +94,9 @@ import static org.drools.model.codegen.execmodel.generator.ConstraintUtil.GREATER_THAN_PREFIX; import static org.drools.model.codegen.execmodel.generator.ConstraintUtil.LESS_OR_EQUAL_PREFIX; import static org.drools.model.codegen.execmodel.generator.ConstraintUtil.LESS_THAN_PREFIX; +import static org.drools.model.codegen.execmodel.generator.expressiontyper.ExpressionTyper.convertArithmeticBinaryToMethodCall; +import static org.drools.model.codegen.execmodel.generator.expressiontyper.ExpressionTyper.getBinaryTypeAfterConversion; +import static org.drools.model.codegen.execmodel.generator.expressiontyper.ExpressionTyper.shouldConvertArithmeticBinaryToMethodCall; import static org.drools.util.StringUtils.lcFirstForBean; import static org.drools.model.codegen.execmodel.generator.DrlxParseUtil.THIS_PLACEHOLDER; import static org.drools.model.codegen.execmodel.generator.DrlxParseUtil.createConstraintCompiler; @@ -196,11 +199,23 @@ private void logWarnIfNoReactOnCausedByVariableFromDifferentPattern(DrlxParseRes } private void addDeclaration(DrlxExpression drlx, SingleDrlxParseSuccess singleResult, String bindId) { - DeclarationSpec decl = context.addDeclaration( bindId, singleResult.getLeftExprTypeBeforeCoercion() ); + DeclarationSpec decl = context.addDeclaration(bindId, getDeclarationType(drlx, singleResult)); if (drlx.getExpr() instanceof NameExpr) { decl.setBoundVariable( PrintUtil.printNode(drlx.getExpr()) ); } else if (drlx.getExpr() instanceof BinaryExpr) { - decl.setBoundVariable(PrintUtil.printNode(drlx.getExpr().asBinaryExpr().getLeft())); + Expression leftMostExpression = getLeftMostExpression(drlx.getExpr().asBinaryExpr()); + decl.setBoundVariable(PrintUtil.printNode(leftMostExpression)); + if (singleResult.getExpr() instanceof MethodCallExpr) { + // BinaryExpr was converted to MethodCallExpr. Create a TypedExpression for the leftmost expression of the BinaryExpr + ExpressionTyperContext expressionTyperContext = new ExpressionTyperContext(); + ExpressionTyper expressionTyper = new ExpressionTyper(context, singleResult.getPatternType(), bindId, false, expressionTyperContext); + TypedExpressionResult leftTypedExpressionResult = expressionTyper.toTypedExpression(leftMostExpression); + Optional optLeft = leftTypedExpressionResult.getTypedExpression(); + if (!optLeft.isPresent()) { + throw new IllegalStateException("Cannot create TypedExpression for " + drlx.getExpr().asBinaryExpr().getLeft()); + } + singleResult.setBoundExpr(optLeft.get()); + } } decl.setBelongingPatternDescr(context.getCurrentPatternDescr()); singleResult.setExprBinding( bindId ); @@ -210,6 +225,24 @@ private void addDeclaration(DrlxExpression drlx, SingleDrlxParseSuccess singleRe } } + private static Class getDeclarationType(DrlxExpression drlx, SingleDrlxParseSuccess singleResult) { + if (drlx.getBind() != null && drlx.getExpr() instanceof EnclosedExpr) { + // in case of enclosed, bind type should be the calculation result type + // If drlx.getBind() == null, a bind variable is inside the enclosed expression, leave it to the default behavior + return (Class)singleResult.getExprType(); + } else { + return singleResult.getLeftExprTypeBeforeCoercion(); + } + } + + private Expression getLeftMostExpression(BinaryExpr binaryExpr) { + Expression left = binaryExpr.getLeft(); + if (left instanceof BinaryExpr) { + return getLeftMostExpression((BinaryExpr) left); + } + return left; + } + /* This is the entry point for Constraint Transformation from a parsed MVEL constraint to a Java Expression @@ -656,17 +689,16 @@ private DrlxParseResult parseBinaryExpr(BinaryExpr binaryExpr, Class patternT Expression combo; - boolean arithmeticExpr = ARITHMETIC_OPERATORS.contains(operator); boolean isBetaConstraint = right.getExpression() != null && hasDeclarationFromOtherPattern( expressionTyperContext ); boolean requiresSplit = operator == BinaryExpr.Operator.AND && binaryExpr.getRight() instanceof HalfBinaryExpr && !isBetaConstraint; + Type exprType = isBooleanOperator( operator ) ? boolean.class : left.getType(); + if (equalityExpr) { combo = getEqualityExpression( left, right, operator ).expression; - } else if (arithmeticExpr && (left.isBigDecimal())) { - ConstraintCompiler constraintCompiler = createConstraintCompiler(this.context, of(patternType)); - CompiledExpressionResult compiledExpressionResult = constraintCompiler.compileExpression(binaryExpr); - - combo = compiledExpressionResult.getExpression(); + } else if (shouldConvertArithmeticBinaryToMethodCall(operator, left.getType(), right.getType())) { + combo = convertArithmeticBinaryToMethodCall(binaryExpr, of(patternType), this.context); + exprType = getBinaryTypeAfterConversion(left.getType(), right.getType()); } else { if (left.getExpression() == null || right.getExpression() == null) { return new DrlxParseFail(new ParseExpressionErrorResult(drlxExpr)); @@ -694,7 +726,7 @@ private DrlxParseResult parseBinaryExpr(BinaryExpr binaryExpr, Class patternT constraintType = Index.ConstraintType.FORALL_SELF_JOIN; } - return new SingleDrlxParseSuccess(patternType, bindingId, combo, isBooleanOperator( operator ) ? boolean.class : left.getType()) + return new SingleDrlxParseSuccess(patternType, bindingId, combo, exprType) .setDecodeConstraintType( constraintType ) .setUsedDeclarations( expressionTyperContext.getUsedDeclarations() ) .setUsedDeclarationsOnLeft( usedDeclarationsOnLeft ) @@ -1007,4 +1039,8 @@ private Optional convertBigDecimalArithmetic(MethodCallExpr metho } return Optional.empty(); } + + public static boolean isArithmeticOperator(BinaryExpr.Operator operator) { + return ARITHMETIC_OPERATORS.contains(operator); + } } diff --git a/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/expressiontyper/ExpressionTyper.java b/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/expressiontyper/ExpressionTyper.java index eb101f499cf..e616ac1c5e8 100644 --- a/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/expressiontyper/ExpressionTyper.java +++ b/drools-model/drools-model-codegen/src/main/java/org/drools/model/codegen/execmodel/generator/expressiontyper/ExpressionTyper.java @@ -22,6 +22,7 @@ import java.lang.reflect.Modifier; import java.lang.reflect.ParameterizedType; import java.lang.reflect.TypeVariable; +import java.math.BigDecimal; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -89,6 +90,8 @@ import org.drools.mvel.parser.ast.expr.OOPathExpr; import org.drools.mvel.parser.ast.expr.PointFreeExpr; import org.drools.mvel.parser.printer.PrintUtil; +import org.drools.mvelcompiler.CompiledExpressionResult; +import org.drools.mvelcompiler.ConstraintCompiler; import org.drools.mvelcompiler.util.BigDecimalArgumentCoercion; import org.drools.util.MethodUtils; import org.drools.util.TypeResolver; @@ -99,6 +102,7 @@ import static java.util.Optional.empty; import static java.util.Optional.of; import static org.drools.model.codegen.execmodel.generator.DrlxParseUtil.THIS_PLACEHOLDER; +import static org.drools.model.codegen.execmodel.generator.DrlxParseUtil.createConstraintCompiler; import static org.drools.model.codegen.execmodel.generator.DrlxParseUtil.findRootNodeViaParent; import static org.drools.model.codegen.execmodel.generator.DrlxParseUtil.getClassFromContext; import static org.drools.model.codegen.execmodel.generator.DrlxParseUtil.getClassFromType; @@ -113,6 +117,7 @@ import static org.drools.model.codegen.execmodel.generator.DrlxParseUtil.toStringLiteral; import static org.drools.model.codegen.execmodel.generator.DrlxParseUtil.transformDrlNameExprToNameExpr; import static org.drools.model.codegen.execmodel.generator.DrlxParseUtil.trasformHalfBinaryToBinary; +import static org.drools.model.codegen.execmodel.generator.drlxparse.ConstraintParser.isArithmeticOperator; import static org.drools.model.codegen.execmodel.generator.expressiontyper.FlattenScope.flattenScope; import static org.drools.model.codegen.execmodel.generator.expressiontyper.FlattenScope.transformFullyQualifiedInlineCastExpr; import static org.drools.mvel.parser.MvelParser.parseType; @@ -229,7 +234,14 @@ private Optional toTypedExpressionRec(Expression drlxExpr) { right = coerced.getCoercedRight(); final BinaryExpr combo = new BinaryExpr(left.getExpression(), right.getExpression(), operator); - return of(new TypedExpression(combo, left.getType())); + + if (shouldConvertArithmeticBinaryToMethodCall(operator, left.getType(), right.getType())) { + Expression expression = convertArithmeticBinaryToMethodCall(combo, of(typeCursor), ruleContext); + java.lang.reflect.Type binaryType = getBinaryTypeAfterConversion(left.getType(), right.getType()); + return of(new TypedExpression(expression, binaryType)); + } else { + return of(new TypedExpression(combo, left.getType())); + } } if (drlxExpr instanceof HalfBinaryExpr) { @@ -800,7 +812,38 @@ private TypedExpressionCursor binaryExpr(BinaryExpr binaryExpr) { TypedExpression rightTypedExpression = right.getTypedExpression() .orElseThrow(() -> new NoSuchElementException("TypedExpressionResult doesn't contain TypedExpression!")); binaryExpr.setRight(rightTypedExpression.getExpression()); - return new TypedExpressionCursor(binaryExpr, getBinaryType(leftTypedExpression, rightTypedExpression, binaryExpr.getOperator())); + if (shouldConvertArithmeticBinaryToMethodCall(binaryExpr.getOperator(), leftTypedExpression.getType(), rightTypedExpression.getType())) { + Expression compiledExpression = convertArithmeticBinaryToMethodCall(binaryExpr, leftTypedExpression.getOriginalPatternType(), ruleContext); + java.lang.reflect.Type binaryType = getBinaryTypeAfterConversion(leftTypedExpression.getType(), rightTypedExpression.getType()); + return new TypedExpressionCursor(compiledExpression, binaryType); + } else { + java.lang.reflect.Type binaryType = getBinaryType(leftTypedExpression, rightTypedExpression, binaryExpr.getOperator()); + return new TypedExpressionCursor(binaryExpr, binaryType); + } + } + + /* + * Converts arithmetic binary expression (including coercion) to method call using ConstraintCompiler. + * This method can be generic, so we may centralize the calls in drools-model + */ + public static Expression convertArithmeticBinaryToMethodCall(BinaryExpr binaryExpr, Optional> originalPatternType, RuleContext ruleContext) { + ConstraintCompiler constraintCompiler = createConstraintCompiler(ruleContext, originalPatternType); + CompiledExpressionResult compiledExpressionResult = constraintCompiler.compileExpression(printNode(binaryExpr)); + return compiledExpressionResult.getExpression(); + } + + /* + * BigDecimal arithmetic operations should be converted to method calls. We may also apply this to BigInteger. + */ + public static boolean shouldConvertArithmeticBinaryToMethodCall(BinaryExpr.Operator operator, java.lang.reflect.Type leftType, java.lang.reflect.Type rightType) { + return isArithmeticOperator(operator) && (leftType.equals(BigDecimal.class) || rightType.equals(BigDecimal.class)); + } + + /* + * After arithmetic to method call conversion, BigDecimal should take precedence regardless of left or right. We may also apply this to BigInteger. + */ + public static java.lang.reflect.Type getBinaryTypeAfterConversion(java.lang.reflect.Type leftType, java.lang.reflect.Type rightType) { + return (leftType.equals(BigDecimal.class) || rightType.equals(BigDecimal.class)) ? BigDecimal.class : leftType; } private java.lang.reflect.Type getBinaryType(TypedExpression leftTypedExpression, TypedExpression rightTypedExpression, Operator operator) { @@ -907,6 +950,9 @@ private void promoteBigDecimalParameters(MethodCallExpr methodCallExpr, Class[] Expression argumentExpression = methodCallExpr.getArgument(i); if (argumentType != actualArgumentType) { + // unbind the original argumentExpression first, otherwise setArgument() will remove the argumentExpression from coercedExpression.childrenNodes + // It will result in failing to find DrlNameExpr in AST at DrlsParseUtil.transformDrlNameExprToNameExpr() + methodCallExpr.replace(argumentExpression, new NameExpr("placeholder")); Expression coercedExpression = new BigDecimalArgumentCoercion().coercedArgument(argumentType, actualArgumentType, argumentExpression); methodCallExpr.setArgument(i, coercedExpression); } diff --git a/drools-model/drools-model-codegen/src/test/java/org/drools/model/codegen/execmodel/bigdecimaltest/BigDecimalTest.java b/drools-model/drools-model-codegen/src/test/java/org/drools/model/codegen/execmodel/bigdecimaltest/BigDecimalTest.java index 50655e281cd..5ff7510398e 100644 --- a/drools-model/drools-model-codegen/src/test/java/org/drools/model/codegen/execmodel/bigdecimaltest/BigDecimalTest.java +++ b/drools-model/drools-model-codegen/src/test/java/org/drools/model/codegen/execmodel/bigdecimaltest/BigDecimalTest.java @@ -720,4 +720,96 @@ public void bigDecimalEqualityWithDifferentScale_shouldBeEqual() { // BigDecimal("1.0") and BigDecimal("1.00") are considered as equal because exec-model uses EvaluationUtil.equals() which is based on compareTo() assertThat(result).contains(new BigDecimal("1.00")); } + + @Test + public void bigDecimalCoercionInMethodArgument_shouldNotFailToBuild() { + // KIE-748 + String str = + "package org.drools.modelcompiler.bigdecimals\n" + + "import " + BDFact.class.getCanonicalName() + ";\n" + + "import static " + BigDecimalTest.class.getCanonicalName() + ".intToString;\n" + + "rule \"Rule 1a\"\n" + + " when\n" + + " BDFact( intToString(value2 - 1) == \"2\" )\n" + + " then\n" + + "end"; + + KieSession ksession = getKieSession(str); + + BDFact bdFact = new BDFact(); + bdFact.setValue2(new BigDecimal("3")); + + ksession.insert(bdFact); + + assertThat(ksession.fireAllRules()).isEqualTo(1); + } + + @Test + public void bigDecimalCoercionInNestedMethodArgument_shouldNotFailToBuild() { + // KIE-748 + String str = + "package org.drools.modelcompiler.bigdecimals\n" + + "import " + BDFact.class.getCanonicalName() + ";\n" + + "import static " + BigDecimalTest.class.getCanonicalName() + ".intToString;\n" + + "rule \"Rule 1a\"\n" + + " when\n" + + " BDFact( intToString(value1 * (value2 - 1)) == \"20\" )\n" + + " then\n" + + "end"; + + KieSession ksession = getKieSession(str); + + BDFact bdFact = new BDFact(); + bdFact.setValue1(new BigDecimal("10")); + bdFact.setValue2(new BigDecimal("3")); + + ksession.insert(bdFact); + + assertThat(ksession.fireAllRules()).isEqualTo(1); + } + + public static String intToString(int value) { + return Integer.toString(value); + } + + @Test + public void bindVariableToBigDecimalCoercion2Operands_shouldBindCorrectResult() { + bindVariableToBigDecimalCoercion("$var : (1000 * value1)"); + } + + @Test + public void bindVariableToBigDecimalCoercion3Operands_shouldBindCorrectResult() { + bindVariableToBigDecimalCoercion("$var : (100000 * value1 / 100)"); + } + + @Test + public void bindVariableToBigDecimalCoercion3OperandsWithParentheses_shouldBindCorrectResult() { + bindVariableToBigDecimalCoercion("$var : ((100000 * value1) / 100)"); + } + + private void bindVariableToBigDecimalCoercion(String binding) { + // KIE-775 + String str = + "package org.drools.modelcompiler.bigdecimals\n" + + "import " + BDFact.class.getCanonicalName() + ";\n" + + "global java.util.List result;\n" + + "rule R1\n" + + " when\n" + + " BDFact( " + binding + " )\n" + + " then\n" + + " result.add($var);\n" + + "end"; + + KieSession ksession = getKieSession(str); + List result = new ArrayList<>(); + ksession.setGlobal("result", result); + + BDFact bdFact = new BDFact(); + bdFact.setValue1(new BigDecimal("80")); + + ksession.insert(bdFact); + ksession.fireAllRules(); + + assertThat(result).contains(new BigDecimal("80000")); + } }