Skip to content

Commit

Permalink
[ES|QL] Optional named arguments for function in map (#118619)
Browse files Browse the repository at this point in the history
* MapExpression for functions
  • Loading branch information
fang-xing-esql authored Jan 16, 2025
1 parent 377d893 commit 11fbc8c
Show file tree
Hide file tree
Showing 35 changed files with 4,053 additions and 2,015 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/118619.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 118619
summary: Optional named arguments for function in map
area: EQL
type: enhancement
issues: []
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.core.expression;

import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.util.PlanStreamInput;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

/**
* Represent a key-value pair.
*/
public class EntryExpression extends Expression {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
Expression.class,
"EntryExpression",
EntryExpression::readFrom
);

private final Expression key;

private final Expression value;

public EntryExpression(Source source, Expression key, Expression value) {
super(source, List.of(key, value));
this.key = key;
this.value = value;
}

private static EntryExpression readFrom(StreamInput in) throws IOException {
return new EntryExpression(
Source.readFrom((StreamInput & PlanStreamInput) in),
in.readNamedWriteable(Expression.class),
in.readNamedWriteable(Expression.class)
);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
source().writeTo(out);
out.writeNamedWriteable(key);
out.writeNamedWriteable(value);
}

@Override
public String getWriteableName() {
return ENTRY.name;
}

@Override
public Expression replaceChildren(List<Expression> newChildren) {
return new EntryExpression(source(), newChildren.get(0), newChildren.get(1));
}

@Override
protected NodeInfo<? extends Expression> info() {
return NodeInfo.create(this, EntryExpression::new, key, value);
}

public Expression key() {
return key;
}

public Expression value() {
return value;
}

@Override
public DataType dataType() {
return value.dataType();
}

@Override
public Nullability nullable() {
return Nullability.FALSE;
}

@Override
public int hashCode() {
return Objects.hash(key, value);
}

@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}

EntryExpression other = (EntryExpression) obj;
return Objects.equals(key, other.key) && Objects.equals(value, other.value);
}

@Override
public String toString() {
return key.toString() + ":" + value.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ public static List<NamedWriteableRegistry.Entry> expressions() {
entries.add(new NamedWriteableRegistry.Entry(Expression.class, e.name, in -> (Expression) e.reader.read(in)));
}
entries.add(Literal.ENTRY);
entries.addAll(mapExpressions());
return entries;
}

Expand All @@ -45,4 +46,8 @@ public static List<NamedWriteableRegistry.Entry> namedExpressions() {
public static List<NamedWriteableRegistry.Entry> attributes() {
return List.of(FieldAttribute.ENTRY, MetadataAttribute.ENTRY, ReferenceAttribute.ENTRY);
}

public static List<NamedWriteableRegistry.Entry> mapExpressions() {
return List.of(EntryExpression.ENTRY, MapExpression.ENTRY);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.esql.core.expression;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.core.util.PlanStreamInput;

import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

import static org.elasticsearch.xpack.esql.core.type.DataType.UNSUPPORTED;

/**
* Represent a collect of key-value pairs.
*/
public class MapExpression extends Expression {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
Expression.class,
"MapExpression",
MapExpression::readFrom
);

private final List<EntryExpression> entryExpressions;

private final Map<Expression, Expression> map;

private final Map<Object, Expression> keyFoldedMap;

public MapExpression(Source source, List<Expression> entries) {
super(source, entries);
int entryCount = entries.size() / 2;
this.entryExpressions = new ArrayList<>(entryCount);
this.map = new LinkedHashMap<>(entryCount);
// create a map with key folded and source removed to make the retrieval of value easier
this.keyFoldedMap = new LinkedHashMap<>(entryCount);
for (int i = 0; i < entryCount; i++) {
Expression key = entries.get(i * 2);
Expression value = entries.get(i * 2 + 1);
entryExpressions.add(new EntryExpression(key.source(), key, value));
map.put(key, value);
if (key instanceof Literal l) {
this.keyFoldedMap.put(l.value(), value);
}
}
}

private static MapExpression readFrom(StreamInput in) throws IOException {
return new MapExpression(
Source.readFrom((StreamInput & PlanStreamInput) in),
in.readNamedWriteableCollectionAsList(Expression.class)
);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
source().writeTo(out);
out.writeNamedWriteableCollection(children());
}

@Override
public String getWriteableName() {
return ENTRY.name;
}

@Override
public MapExpression replaceChildren(List<Expression> newChildren) {
return new MapExpression(source(), newChildren);
}

@Override
protected NodeInfo<MapExpression> info() {
return NodeInfo.create(this, MapExpression::new, children());
}

public List<EntryExpression> entryExpressions() {
return entryExpressions;
}

public Map<Expression, Expression> map() {
return map;
}

public Map<Object, Expression> keyFoldedMap() {
return keyFoldedMap;
}

@Override
public Nullability nullable() {
return Nullability.FALSE;
}

@Override
public DataType dataType() {
return UNSUPPORTED;
}

@Override
public int hashCode() {
return Objects.hash(entryExpressions);
}

public Expression get(Object key) {
if (key instanceof Expression) {
return map.get(key);
} else {
// the key(literal) could be converted to BytesRef by ConvertStringToByteRef
return keyFoldedMap.containsKey(key) ? keyFoldedMap.get(key) : keyFoldedMap.get(new BytesRef(key.toString()));
}
}

@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}

MapExpression other = (MapExpression) obj;
return Objects.equals(entryExpressions, other.entryExpressions);
}

@Override
public String toString() {
String str = entryExpressions.stream().map(String::valueOf).collect(Collectors.joining(", "));
return "{ " + str + " }";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -253,4 +253,19 @@ private static String acceptedTypesForErrorMsg(String... acceptedTypes) {
return acceptedTypes[0];
}
}

public static TypeResolution isMapExpression(Expression e, String operationName, ParamOrdinal paramOrd) {
if (e instanceof MapExpression == false) {
return new TypeResolution(
format(
null,
"{}argument of [{}] must be a map expression, received [{}]",
paramOrd == null || paramOrd == DEFAULT ? "" : paramOrd.name().toLowerCase(Locale.ROOT) + " ",
operationName,
Expressions.name(e)
)
);
}
return TypeResolution.TYPE_RESOLVED;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public Set<String> getSupportedAnnotationTypes() {
"org.elasticsearch.injection.guice.Inject",
"org.elasticsearch.xpack.esql.expression.function.FunctionInfo",
"org.elasticsearch.xpack.esql.expression.function.Param",
"org.elasticsearch.xpack.esql.expression.function.MapParam",
"org.elasticsearch.rest.ServerlessScope",
"org.elasticsearch.xcontent.ParserConstructor",
"org.elasticsearch.core.UpdateForV9",
Expand Down
Loading

0 comments on commit 11fbc8c

Please sign in to comment.