Skip to content

Commit

Permalink
implement access of map values via field notation in the mvel compiler
Browse files Browse the repository at this point in the history
  • Loading branch information
mariofusco committed Jan 17, 2024
1 parent 50550f7 commit 3c83719
Show file tree
Hide file tree
Showing 16 changed files with 217 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1057,4 +1057,9 @@ private static class RuleUnitMembers {
final Map<String, java.lang.reflect.Type> globals = new HashMap<>();
final Set<String> entryPoints = new HashSet<>();
}

@Override
public String toString() {
return pkg.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,13 @@
import com.github.javaparser.ast.expr.NameExpr;
import com.github.javaparser.ast.expr.ObjectCreationExpr;
import com.github.javaparser.ast.expr.SimpleName;
import com.github.javaparser.ast.expr.StringLiteralExpr;
import com.github.javaparser.ast.expr.VariableDeclarationExpr;
import com.github.javaparser.ast.stmt.BlockStmt;
import com.github.javaparser.ast.stmt.ExpressionStmt;
import com.github.javaparser.ast.type.Type;
import org.drools.base.factmodel.ClassDefinition;
import org.drools.compiler.compiler.MissingDependencyError;
import org.drools.core.common.TruthMaintenanceSystemFactory;
import org.drools.base.factmodel.ClassDefinition;
import org.drools.model.BitMask;
import org.drools.model.bitmask.AllSetButLastBitMask;
import org.drools.model.codegen.execmodel.PackageModel;
Expand All @@ -60,10 +59,9 @@
import org.drools.model.codegen.execmodel.errors.InvalidExpressionErrorResult;
import org.drools.model.codegen.execmodel.errors.MvelCompilationError;
import org.drools.modelcompiler.consequence.DroolsImpl;
import org.drools.mvel.parser.ast.expr.DrlNameExpr;
import org.drools.mvelcompiler.CompiledBlockResult;
import org.drools.mvelcompiler.PreprocessCompiler;
import org.drools.mvelcompiler.MvelCompilerException;
import org.drools.mvelcompiler.PreprocessCompiler;
import org.drools.util.StringUtils;

import static com.github.javaparser.StaticJavaParser.parseExpression;
Expand Down Expand Up @@ -165,9 +163,6 @@ public MethodCallExpr createCall(String consequenceString, BlockStmt ruleVariabl
MethodCallExpr executeCall;
switch (context.getRuleDialect()) {
case JAVA:
if (context.arePrototypesAllowed()) {
rewriteConsequenceForPrototype(ruleConsequence, usedDeclarationInRHS);
}
rewriteReassignedDeclarations(ruleConsequence, usedDeclarationInRHS);
executeCall = executeCall(ruleVariablesBlock, ruleConsequence, usedDeclarationInRHS, onCall);
break;
Expand All @@ -184,38 +179,6 @@ public MethodCallExpr createCall(String consequenceString, BlockStmt ruleVariabl
return executeCall;
}

private void rewriteConsequenceForPrototype(BlockStmt ruleConsequence, Set<String> usedDeclarationInRHS) {
for (AssignExpr assignExpr : ruleConsequence.findAll(AssignExpr.class)) {
if (assignExpr.getTarget().isFieldAccessExpr()) {
FieldAccessExpr fieldAccessExpr = assignExpr.getTarget().asFieldAccessExpr();
String assignedVariable = getAssignedVariable(fieldAccessExpr);
if (assignedVariable != null && usedDeclarationInRHS.contains(assignedVariable) && context.isPrototypeDeclaration(assignedVariable)) {
MethodCallExpr setCall = new MethodCallExpr(new NameExpr(assignedVariable), "put");
setCall.addArgument(new StringLiteralExpr(fieldAccessExpr.getNameAsString()));
setCall.addArgument(assignExpr.getValue());
assignExpr.replace(setCall);
}
}
}

for (FieldAccessExpr fieldAccessExpr : ruleConsequence.findAll(FieldAccessExpr.class)) {
String assignedVariable = getAssignedVariable(fieldAccessExpr);
if ( assignedVariable != null && usedDeclarationInRHS.contains( assignedVariable ) && context.isPrototypeDeclaration( assignedVariable ) ) {
MethodCallExpr getCall = new MethodCallExpr( new NameExpr( assignedVariable ), "get" );
getCall.addArgument( new StringLiteralExpr( fieldAccessExpr.getNameAsString() ) );
fieldAccessExpr.replace(getCall);
}
}
}

private static String getAssignedVariable(FieldAccessExpr fieldAccessExpr) {
Expression scope = fieldAccessExpr.getScope();
if (scope instanceof DrlNameExpr drlName) {
return drlName.getName().toString();
}
return scope instanceof NameExpr ? scope.toString() : null;
}

private void replaceKcontext(BlockStmt ruleConsequence) {
ruleConsequence.findAll( Expression.class )
.stream()
Expand Down Expand Up @@ -254,7 +217,6 @@ private MethodCallExpr createExecuteCallMvel(String consequenceString, BlockStmt
replaceKcontext(compile.statementResults());
rewriteChannels(compile.statementResults());

rewriteConsequenceForPrototype(compile.statementResults(), usedDeclarationInRHS);
return executeCall(ruleVariablesBlock,
compile.statementResults(),
usedDeclarationInRHS,
Expand Down Expand Up @@ -348,13 +310,14 @@ private MethodCallExpr onCall(Collection<String> usedArguments) {
private String preprocessConsequence(String consequence) {
int modifyPos = StringUtils.indexOfOutOfQuotes(consequence, "modify");
int textBlockPos = StringUtils.indexOfOutOfQuotes(consequence, "\"\"\"");
Set<String> prototypes = context.getPrototypeDeclarations();

if (modifyPos < 0 && textBlockPos < 0) {
if (modifyPos < 0 && textBlockPos < 0 && prototypes.isEmpty()) {
return consequence;
}

PreprocessCompiler preprocessCompiler = new PreprocessCompiler();
CompiledBlockResult compile = preprocessCompiler.compile(addCurlyBracesToBlock(consequence));
CompiledBlockResult compile = preprocessCompiler.compile(addCurlyBracesToBlock(consequence), prototypes);

return printNode(compile.statementResults());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
*/
package org.drools.model.codegen.execmodel.generator;

import java.util.Map;
import java.util.Optional;

import com.github.javaparser.ast.expr.MethodCallExpr;
Expand All @@ -40,6 +39,6 @@ public interface DeclarationSpec {
void registerOnPackage(PackageModel packageModel, RuleContext context, BlockStmt ruleBlock);

default boolean isPrototypeDeclaration() {
return getDeclarationClass().isAssignableFrom(Map.class);
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
Expand All @@ -34,6 +35,7 @@
import java.util.Optional;
import java.util.Set;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import com.github.javaparser.ast.body.Parameter;
import com.github.javaparser.ast.expr.Expression;
Expand Down Expand Up @@ -268,6 +270,12 @@ public boolean isPrototypeDeclaration(String id) {
return getDeclarationById(id).map(DeclarationSpec::isPrototypeDeclaration).orElse(false);
}

public Set<String> getPrototypeDeclarations() {
return arePrototypesAllowed() ?
scopedDeclarations.values().stream().filter(DeclarationSpec::isPrototypeDeclaration).map(DeclarationSpec::getBindingId).collect(Collectors.toSet()) :
Collections.emptySet();
}

public DeclarationSpec getDeclarationByIdWithException(String id) {
return getDeclarationById(id).orElseThrow(() -> new UnknownDeclarationException("Unknown declaration: " + id));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.FieldAccessExpr;
import com.github.javaparser.ast.expr.MethodCallExpr;
import com.github.javaparser.ast.expr.StringLiteralExpr;
import com.github.javaparser.ast.expr.VariableDeclarationExpr;
import com.github.javaparser.ast.stmt.ExpressionStmt;
import com.github.javaparser.ast.type.ClassOrInterfaceType;
Expand All @@ -54,6 +55,7 @@
import org.drools.mvelcompiler.ast.VariableDeclaratorTExpr;
import org.drools.mvelcompiler.context.Declaration;
import org.drools.mvelcompiler.context.MvelCompilerContext;
import org.drools.mvelcompiler.util.TypeUtils;
import org.drools.util.ClassUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -132,7 +134,9 @@ public TypedExpression visit(FieldAccessExpr n, Void arg) {
// a part of a larger FieldAccessExpr. e.g. [$p.address] of [$p.address.city]
return tryParseItAsGetter(n, fieldAccessScope)
.orElse(new UnalteredTypedExpression(n));
} else if(parentIsArrayAccessExpr(n)) {
} else if (fieldAccessScope.getType().map(TypeUtils::isMapAccessField).orElse(false) && parentIsAssignExpr(n)) {
return new MapPutExprT(fieldAccessScope, new StringLiteralExpr(n.getName().toString()), rhsOrNull(), fieldAccessScope.getType());
} else if (parentIsArrayAccessExpr(n)) {
return tryParseItAsMap(n, fieldAccessScope)
.map(Optional::of)
.orElseGet(() -> tryParseItAsSetter(n, fieldAccessScope, getRHSType()))
Expand Down Expand Up @@ -378,6 +382,10 @@ private boolean parentIsArrayAccessExpr(Node n) {
return n.getParentNode().filter(p -> p instanceof ArrayAccessExpr).isPresent();
}

private boolean parentIsAssignExpr(Node n) {
return n.getParentNode().filter(p -> p instanceof AssignExpr).isPresent();
}

private Class<?> getRHSType() {
return rhs
.flatMap(TypedExpression::getType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,19 @@

import com.github.javaparser.ast.Node;
import com.github.javaparser.ast.body.VariableDeclarator;
import com.github.javaparser.ast.expr.AssignExpr;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.FieldAccessExpr;
import com.github.javaparser.ast.expr.MethodCallExpr;
import com.github.javaparser.ast.expr.NameExpr;
import com.github.javaparser.ast.expr.StringLiteralExpr;
import com.github.javaparser.ast.expr.TextBlockLiteralExpr;
import com.github.javaparser.ast.stmt.BlockStmt;
import org.drools.mvel.parser.MvelParser;
import org.drools.mvel.parser.ast.expr.DrlNameExpr;
import org.drools.mvel.parser.ast.expr.ModifyStatement;
import org.drools.mvelcompiler.ast.MapGetExprT;
import org.drools.mvelcompiler.ast.MapPutExprT;

import static com.github.javaparser.ast.NodeList.nodeList;

Expand All @@ -41,7 +47,7 @@ public class PreprocessCompiler {

private static final PreprocessPhase preprocessPhase = new PreprocessPhase();

public CompiledBlockResult compile(String mvelBlock) {
public CompiledBlockResult compile(String mvelBlock, Set<String> prototypes) {

BlockStmt mvelExpression = MvelParser.parseBlock(mvelBlock);

Expand All @@ -53,9 +59,9 @@ public CompiledBlockResult compile(String mvelBlock) {
StringLiteralExpr stringLiteralExpr = preprocessPhase.replaceTextBlockWithConcatenatedStrings(e);

parentNode.ifPresent(p -> {
if(p instanceof VariableDeclarator) {
if (p instanceof VariableDeclarator) {
((VariableDeclarator) p).setInitializer(stringLiteralExpr);
} else if(p instanceof MethodCallExpr) {
} else if (p instanceof MethodCallExpr) {
// """exampleString""".formatted("arg0", 2);
((MethodCallExpr) p).setScope(stringLiteralExpr);
}
Expand All @@ -77,6 +83,38 @@ public CompiledBlockResult compile(String mvelBlock) {
s.remove();
});

if (!prototypes.isEmpty()) {
rewriteConsequenceForPrototype(mvelExpression, prototypes);
}

return new CompiledBlockResult(mvelExpression.getStatements()).setUsedBindings(usedBindings);
}

private void rewriteConsequenceForPrototype(BlockStmt ruleConsequence, Set<String> prototypes) {
for (AssignExpr assignExpr : ruleConsequence.findAll(AssignExpr.class)) {
if (assignExpr.getTarget().isFieldAccessExpr()) {
FieldAccessExpr fieldAccessExpr = assignExpr.getTarget().asFieldAccessExpr();
String assignedVariable = getAssignedVariable(fieldAccessExpr);
if (prototypes.contains(assignedVariable)) {
assignExpr.replace(new MapPutExprT(new NameExpr(assignedVariable), new StringLiteralExpr(fieldAccessExpr.getNameAsString()),
assignExpr.getValue(), Optional.empty()).toJavaExpression());
}
}
}

for (FieldAccessExpr fieldAccessExpr : ruleConsequence.findAll(FieldAccessExpr.class)) {
String assignedVariable = getAssignedVariable(fieldAccessExpr);
if (prototypes.contains(assignedVariable)) {
fieldAccessExpr.replace( new MapGetExprT(new NameExpr(assignedVariable), fieldAccessExpr.getNameAsString() ).toJavaExpression() );
}
}
}

private static String getAssignedVariable(FieldAccessExpr fieldAccessExpr) {
Expression scope = fieldAccessExpr.getScope();
if (scope instanceof DrlNameExpr drlName) {
return drlName.getName().toString();
}
return scope instanceof NameExpr ? scope.toString() : null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import java.lang.reflect.Method;
import java.lang.reflect.Type;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
Expand Down Expand Up @@ -57,7 +56,6 @@
import org.drools.mvel.parser.ast.expr.BigIntegerLiteralExpr;
import org.drools.mvel.parser.ast.expr.DrlNameExpr;
import org.drools.mvel.parser.ast.expr.ListCreationLiteralExpression;
import org.drools.mvel.parser.ast.expr.ListCreationLiteralExpressionElement;
import org.drools.mvel.parser.ast.expr.MapCreationLiteralExpression;
import org.drools.mvel.parser.ast.visitor.DrlGenericVisitor;
import org.drools.mvelcompiler.ast.BigDecimalArithmeticExprT;
Expand All @@ -76,6 +74,7 @@
import org.drools.mvelcompiler.ast.ListExprT;
import org.drools.mvelcompiler.ast.LongLiteralExpressionT;
import org.drools.mvelcompiler.ast.MapExprT;
import org.drools.mvelcompiler.ast.MapGetExprT;
import org.drools.mvelcompiler.ast.ObjectCreationExpressionT;
import org.drools.mvelcompiler.ast.SimpleNameTExpr;
import org.drools.mvelcompiler.ast.StringLiteralExpressionT;
Expand All @@ -84,6 +83,7 @@
import org.drools.mvelcompiler.context.Declaration;
import org.drools.mvelcompiler.context.MvelCompilerContext;
import org.drools.mvelcompiler.util.MethodResolutionUtils;
import org.drools.mvelcompiler.util.TypeUtils;
import org.drools.mvelcompiler.util.VisitorContext;
import org.drools.util.ClassUtils;
import org.drools.util.MethodUtils.NullType;
Expand Down Expand Up @@ -236,6 +236,17 @@ private Optional<TypedExpression> asPropertyAccessorOfRootPattern(SimpleName n)
@Override
public TypedExpression visit(FieldAccessExpr n, VisitorContext arg) {
TypedExpression scope = n.getScope().accept(this, arg);
if (scope.getType().map(TypeUtils::isMapAccessField).orElse(false)) {
String key = n.getName().toString();

// "size" is an edge case and could mean both the size of the map and the value of the key "size".
// To keep backward compatibility it is necessary to assume that it is the size of the map,
// but this implies that at the moment it is not possible to read the value of the "size" key
// using the field access notation
if (!"size".equals(key)) {
return new MapGetExprT(scope, key);
}
}
return n.getName().accept(this, new VisitorContext(scope));
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.drools.mvelcompiler.ast;

import java.lang.reflect.Type;
import java.util.Optional;

import com.github.javaparser.ast.Node;
import com.github.javaparser.ast.NodeList;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.MethodCallExpr;
import com.github.javaparser.ast.expr.StringLiteralExpr;

public class MapGetExprT implements TypedExpression {

private final Expression name;
private final String key;

public MapGetExprT(TypedExpression name, String key) {
this((Expression) name.toJavaExpression(), key);
}

public MapGetExprT(Expression name, String key) {
this.name = name;
this.key = key;
}

@Override
public Optional<Type> getType() {
return Optional.empty();
}

@Override
public Node toJavaExpression() {
return new MethodCallExpr(name, "get", NodeList.nodeList(new StringLiteralExpr(key)));
}
}
Loading

0 comments on commit 3c83719

Please sign in to comment.