Skip to content

Commit

Permalink
[KIE-775] drools executable-model fails with a bind variable to a cal…
Browse files Browse the repository at this point in the history
…culation result of int and BigDecimal (apache#5636)

(cherry picked from commit fe2dbac)
  • Loading branch information
tkobayas authored and baldimir committed Feb 28, 2024
1 parent 8d8a7f0 commit edd542d
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ public ArithmeticCoercedExpression(TypedExpression left, TypedExpression right,
this.operator = operator;
}

/*
* This coercion only deals with String vs Numeric types.
* BigDecimal arithmetic operation is handled by ExpressionTyper.convertArithmeticBinaryToMethodCall()
*/
public ArithmeticCoercedExpressionResult coerce() {

if (!requiresCoercion()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@
import static org.drools.model.codegen.execmodel.generator.ConstraintUtil.GREATER_THAN_PREFIX;
import static org.drools.model.codegen.execmodel.generator.ConstraintUtil.LESS_OR_EQUAL_PREFIX;
import static org.drools.model.codegen.execmodel.generator.ConstraintUtil.LESS_THAN_PREFIX;
import static org.drools.model.codegen.execmodel.generator.expressiontyper.ExpressionTyper.convertArithmeticBinaryToMethodCall;
import static org.drools.model.codegen.execmodel.generator.expressiontyper.ExpressionTyper.getBinaryTypeAfterConversion;
import static org.drools.model.codegen.execmodel.generator.expressiontyper.ExpressionTyper.shouldConvertArithmeticBinaryToMethodCall;
import static org.drools.util.StringUtils.lcFirstForBean;
import static org.drools.model.codegen.execmodel.generator.DrlxParseUtil.THIS_PLACEHOLDER;
import static org.drools.model.codegen.execmodel.generator.DrlxParseUtil.createConstraintCompiler;
Expand Down Expand Up @@ -196,11 +199,23 @@ private void logWarnIfNoReactOnCausedByVariableFromDifferentPattern(DrlxParseRes
}

private void addDeclaration(DrlxExpression drlx, SingleDrlxParseSuccess singleResult, String bindId) {
DeclarationSpec decl = context.addDeclaration( bindId, singleResult.getLeftExprTypeBeforeCoercion() );
DeclarationSpec decl = context.addDeclaration(bindId, getDeclarationType(drlx, singleResult));
if (drlx.getExpr() instanceof NameExpr) {
decl.setBoundVariable( PrintUtil.printNode(drlx.getExpr()) );
} else if (drlx.getExpr() instanceof BinaryExpr) {
decl.setBoundVariable(PrintUtil.printNode(drlx.getExpr().asBinaryExpr().getLeft()));
Expression leftMostExpression = getLeftMostExpression(drlx.getExpr().asBinaryExpr());
decl.setBoundVariable(PrintUtil.printNode(leftMostExpression));
if (singleResult.getExpr() instanceof MethodCallExpr) {
// BinaryExpr was converted to MethodCallExpr. Create a TypedExpression for the leftmost expression of the BinaryExpr
ExpressionTyperContext expressionTyperContext = new ExpressionTyperContext();
ExpressionTyper expressionTyper = new ExpressionTyper(context, singleResult.getPatternType(), bindId, false, expressionTyperContext);
TypedExpressionResult leftTypedExpressionResult = expressionTyper.toTypedExpression(leftMostExpression);
Optional<TypedExpression> optLeft = leftTypedExpressionResult.getTypedExpression();
if (!optLeft.isPresent()) {
throw new IllegalStateException("Cannot create TypedExpression for " + drlx.getExpr().asBinaryExpr().getLeft());
}
singleResult.setBoundExpr(optLeft.get());
}
}
decl.setBelongingPatternDescr(context.getCurrentPatternDescr());
singleResult.setExprBinding( bindId );
Expand All @@ -210,6 +225,24 @@ private void addDeclaration(DrlxExpression drlx, SingleDrlxParseSuccess singleRe
}
}

private static Class<?> getDeclarationType(DrlxExpression drlx, SingleDrlxParseSuccess singleResult) {
if (drlx.getBind() != null && drlx.getExpr() instanceof EnclosedExpr) {
// in case of enclosed, bind type should be the calculation result type
// If drlx.getBind() == null, a bind variable is inside the enclosed expression, leave it to the default behavior
return (Class<?>)singleResult.getExprType();
} else {
return singleResult.getLeftExprTypeBeforeCoercion();
}
}

private Expression getLeftMostExpression(BinaryExpr binaryExpr) {
Expression left = binaryExpr.getLeft();
if (left instanceof BinaryExpr) {
return getLeftMostExpression((BinaryExpr) left);
}
return left;
}

/*
This is the entry point for Constraint Transformation from a parsed MVEL constraint
to a Java Expression
Expand Down Expand Up @@ -656,17 +689,16 @@ private DrlxParseResult parseBinaryExpr(BinaryExpr binaryExpr, Class<?> patternT

Expression combo;

boolean arithmeticExpr = ARITHMETIC_OPERATORS.contains(operator);
boolean isBetaConstraint = right.getExpression() != null && hasDeclarationFromOtherPattern( expressionTyperContext );
boolean requiresSplit = operator == BinaryExpr.Operator.AND && binaryExpr.getRight() instanceof HalfBinaryExpr && !isBetaConstraint;

Type exprType = isBooleanOperator( operator ) ? boolean.class : left.getType();

if (equalityExpr) {
combo = getEqualityExpression( left, right, operator ).expression;
} else if (arithmeticExpr && (left.isBigDecimal())) {
ConstraintCompiler constraintCompiler = createConstraintCompiler(this.context, of(patternType));
CompiledExpressionResult compiledExpressionResult = constraintCompiler.compileExpression(binaryExpr);

combo = compiledExpressionResult.getExpression();
} else if (shouldConvertArithmeticBinaryToMethodCall(operator, left.getType(), right.getType())) {
combo = convertArithmeticBinaryToMethodCall(binaryExpr, of(patternType), this.context);
exprType = getBinaryTypeAfterConversion(left.getType(), right.getType());
} else {
if (left.getExpression() == null || right.getExpression() == null) {
return new DrlxParseFail(new ParseExpressionErrorResult(drlxExpr));
Expand Down Expand Up @@ -694,7 +726,7 @@ private DrlxParseResult parseBinaryExpr(BinaryExpr binaryExpr, Class<?> patternT
constraintType = Index.ConstraintType.FORALL_SELF_JOIN;
}

return new SingleDrlxParseSuccess(patternType, bindingId, combo, isBooleanOperator( operator ) ? boolean.class : left.getType())
return new SingleDrlxParseSuccess(patternType, bindingId, combo, exprType)
.setDecodeConstraintType( constraintType )
.setUsedDeclarations( expressionTyperContext.getUsedDeclarations() )
.setUsedDeclarationsOnLeft( usedDeclarationsOnLeft )
Expand Down Expand Up @@ -1007,4 +1039,8 @@ private Optional<DrlxParseFail> convertBigDecimalArithmetic(MethodCallExpr metho
}
return Optional.empty();
}

public static boolean isArithmeticOperator(BinaryExpr.Operator operator) {
return ARITHMETIC_OPERATORS.contains(operator);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.lang.reflect.Modifier;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.TypeVariable;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
Expand Down Expand Up @@ -89,6 +90,8 @@
import org.drools.mvel.parser.ast.expr.OOPathExpr;
import org.drools.mvel.parser.ast.expr.PointFreeExpr;
import org.drools.mvel.parser.printer.PrintUtil;
import org.drools.mvelcompiler.CompiledExpressionResult;
import org.drools.mvelcompiler.ConstraintCompiler;
import org.drools.mvelcompiler.util.BigDecimalArgumentCoercion;
import org.drools.util.MethodUtils;
import org.drools.util.TypeResolver;
Expand All @@ -99,6 +102,7 @@
import static java.util.Optional.empty;
import static java.util.Optional.of;
import static org.drools.model.codegen.execmodel.generator.DrlxParseUtil.THIS_PLACEHOLDER;
import static org.drools.model.codegen.execmodel.generator.DrlxParseUtil.createConstraintCompiler;
import static org.drools.model.codegen.execmodel.generator.DrlxParseUtil.findRootNodeViaParent;
import static org.drools.model.codegen.execmodel.generator.DrlxParseUtil.getClassFromContext;
import static org.drools.model.codegen.execmodel.generator.DrlxParseUtil.getClassFromType;
Expand All @@ -113,6 +117,7 @@
import static org.drools.model.codegen.execmodel.generator.DrlxParseUtil.toStringLiteral;
import static org.drools.model.codegen.execmodel.generator.DrlxParseUtil.transformDrlNameExprToNameExpr;
import static org.drools.model.codegen.execmodel.generator.DrlxParseUtil.trasformHalfBinaryToBinary;
import static org.drools.model.codegen.execmodel.generator.drlxparse.ConstraintParser.isArithmeticOperator;
import static org.drools.model.codegen.execmodel.generator.expressiontyper.FlattenScope.flattenScope;
import static org.drools.model.codegen.execmodel.generator.expressiontyper.FlattenScope.transformFullyQualifiedInlineCastExpr;
import static org.drools.mvel.parser.MvelParser.parseType;
Expand Down Expand Up @@ -229,7 +234,14 @@ private Optional<TypedExpression> toTypedExpressionRec(Expression drlxExpr) {
right = coerced.getCoercedRight();

final BinaryExpr combo = new BinaryExpr(left.getExpression(), right.getExpression(), operator);
return of(new TypedExpression(combo, left.getType()));

if (shouldConvertArithmeticBinaryToMethodCall(operator, left.getType(), right.getType())) {
Expression expression = convertArithmeticBinaryToMethodCall(combo, of(typeCursor), ruleContext);
java.lang.reflect.Type binaryType = getBinaryTypeAfterConversion(left.getType(), right.getType());
return of(new TypedExpression(expression, binaryType));
} else {
return of(new TypedExpression(combo, left.getType()));
}
}

if (drlxExpr instanceof HalfBinaryExpr) {
Expand Down Expand Up @@ -800,7 +812,38 @@ private TypedExpressionCursor binaryExpr(BinaryExpr binaryExpr) {
TypedExpression rightTypedExpression = right.getTypedExpression()
.orElseThrow(() -> new NoSuchElementException("TypedExpressionResult doesn't contain TypedExpression!"));
binaryExpr.setRight(rightTypedExpression.getExpression());
return new TypedExpressionCursor(binaryExpr, getBinaryType(leftTypedExpression, rightTypedExpression, binaryExpr.getOperator()));
if (shouldConvertArithmeticBinaryToMethodCall(binaryExpr.getOperator(), leftTypedExpression.getType(), rightTypedExpression.getType())) {
Expression compiledExpression = convertArithmeticBinaryToMethodCall(binaryExpr, leftTypedExpression.getOriginalPatternType(), ruleContext);
java.lang.reflect.Type binaryType = getBinaryTypeAfterConversion(leftTypedExpression.getType(), rightTypedExpression.getType());
return new TypedExpressionCursor(compiledExpression, binaryType);
} else {
java.lang.reflect.Type binaryType = getBinaryType(leftTypedExpression, rightTypedExpression, binaryExpr.getOperator());
return new TypedExpressionCursor(binaryExpr, binaryType);
}
}

/*
* Converts arithmetic binary expression (including coercion) to method call using ConstraintCompiler.
* This method can be generic, so we may centralize the calls in drools-model
*/
public static Expression convertArithmeticBinaryToMethodCall(BinaryExpr binaryExpr, Optional<Class<?>> originalPatternType, RuleContext ruleContext) {
ConstraintCompiler constraintCompiler = createConstraintCompiler(ruleContext, originalPatternType);
CompiledExpressionResult compiledExpressionResult = constraintCompiler.compileExpression(printNode(binaryExpr));
return compiledExpressionResult.getExpression();
}

/*
* BigDecimal arithmetic operations should be converted to method calls. We may also apply this to BigInteger.
*/
public static boolean shouldConvertArithmeticBinaryToMethodCall(BinaryExpr.Operator operator, java.lang.reflect.Type leftType, java.lang.reflect.Type rightType) {
return isArithmeticOperator(operator) && (leftType.equals(BigDecimal.class) || rightType.equals(BigDecimal.class));
}

/*
* After arithmetic to method call conversion, BigDecimal should take precedence regardless of left or right. We may also apply this to BigInteger.
*/
public static java.lang.reflect.Type getBinaryTypeAfterConversion(java.lang.reflect.Type leftType, java.lang.reflect.Type rightType) {
return (leftType.equals(BigDecimal.class) || rightType.equals(BigDecimal.class)) ? BigDecimal.class : leftType;
}

private java.lang.reflect.Type getBinaryType(TypedExpression leftTypedExpression, TypedExpression rightTypedExpression, Operator operator) {
Expand Down Expand Up @@ -907,6 +950,9 @@ private void promoteBigDecimalParameters(MethodCallExpr methodCallExpr, Class[]
Expression argumentExpression = methodCallExpr.getArgument(i);

if (argumentType != actualArgumentType) {
// unbind the original argumentExpression first, otherwise setArgument() will remove the argumentExpression from coercedExpression.childrenNodes
// It will result in failing to find DrlNameExpr in AST at DrlsParseUtil.transformDrlNameExprToNameExpr()
methodCallExpr.replace(argumentExpression, new NameExpr("placeholder"));
Expression coercedExpression = new BigDecimalArgumentCoercion().coercedArgument(argumentType, actualArgumentType, argumentExpression);
methodCallExpr.setArgument(i, coercedExpression);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -720,4 +720,96 @@ public void bigDecimalEqualityWithDifferentScale_shouldBeEqual() {
// BigDecimal("1.0") and BigDecimal("1.00") are considered as equal because exec-model uses EvaluationUtil.equals() which is based on compareTo()
assertThat(result).contains(new BigDecimal("1.00"));
}

@Test
public void bigDecimalCoercionInMethodArgument_shouldNotFailToBuild() {
// KIE-748
String str =
"package org.drools.modelcompiler.bigdecimals\n" +
"import " + BDFact.class.getCanonicalName() + ";\n" +
"import static " + BigDecimalTest.class.getCanonicalName() + ".intToString;\n" +
"rule \"Rule 1a\"\n" +
" when\n" +
" BDFact( intToString(value2 - 1) == \"2\" )\n" +
" then\n" +
"end";

KieSession ksession = getKieSession(str);

BDFact bdFact = new BDFact();
bdFact.setValue2(new BigDecimal("3"));

ksession.insert(bdFact);

assertThat(ksession.fireAllRules()).isEqualTo(1);
}

@Test
public void bigDecimalCoercionInNestedMethodArgument_shouldNotFailToBuild() {
// KIE-748
String str =
"package org.drools.modelcompiler.bigdecimals\n" +
"import " + BDFact.class.getCanonicalName() + ";\n" +
"import static " + BigDecimalTest.class.getCanonicalName() + ".intToString;\n" +
"rule \"Rule 1a\"\n" +
" when\n" +
" BDFact( intToString(value1 * (value2 - 1)) == \"20\" )\n" +
" then\n" +
"end";

KieSession ksession = getKieSession(str);

BDFact bdFact = new BDFact();
bdFact.setValue1(new BigDecimal("10"));
bdFact.setValue2(new BigDecimal("3"));

ksession.insert(bdFact);

assertThat(ksession.fireAllRules()).isEqualTo(1);
}

public static String intToString(int value) {
return Integer.toString(value);
}

@Test
public void bindVariableToBigDecimalCoercion2Operands_shouldBindCorrectResult() {
bindVariableToBigDecimalCoercion("$var : (1000 * value1)");
}

@Test
public void bindVariableToBigDecimalCoercion3Operands_shouldBindCorrectResult() {
bindVariableToBigDecimalCoercion("$var : (100000 * value1 / 100)");
}

@Test
public void bindVariableToBigDecimalCoercion3OperandsWithParentheses_shouldBindCorrectResult() {
bindVariableToBigDecimalCoercion("$var : ((100000 * value1) / 100)");
}

private void bindVariableToBigDecimalCoercion(String binding) {
// KIE-775
String str =
"package org.drools.modelcompiler.bigdecimals\n" +
"import " + BDFact.class.getCanonicalName() + ";\n" +
"global java.util.List result;\n" +
"rule R1\n" +
" when\n" +
" BDFact( " + binding + " )\n" +
" then\n" +
" result.add($var);\n" +
"end";

KieSession ksession = getKieSession(str);
List<BigDecimal> result = new ArrayList<>();
ksession.setGlobal("result", result);

BDFact bdFact = new BDFact();
bdFact.setValue1(new BigDecimal("80"));

ksession.insert(bdFact);
ksession.fireAllRules();

assertThat(result).contains(new BigDecimal("80000"));
}
}

0 comments on commit edd542d

Please sign in to comment.