Merge pull request #408 from RumbleDB/FixCountOptimizationWithLetGroup
Fix count optimization with let group
ghislainfourny authored Dec 9, 2019
2 parents 9991f11 + bd43bd3 commit 8d33067
Showing 5 changed files with 144 additions and 37 deletions.
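At its core, the fix distinguishes, among the columns handed to the let-clause UDF, those whose full Kryo-serialized sequences are needed from those for which only a pre-computed count is needed, the situation that arises when a let clause such as let $c := count($i) follows a group by. A minimal, self-contained sketch of that split, with an invented enum and invented names standing in for sparksoniq's DynamicContext.VariableDependency, might look like this:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hedged sketch only: all names here are invented stand-ins for the real
// DynamicContext.VariableDependency machinery changed in this commit.
public class ColumnSplitSketch {
    enum VariableDependency { FULL, COUNT }

    public static void main(String[] args) {
        List<String> columns = Arrays.asList("i", "y", "t");
        Map<String, VariableDependency> dependencies = new HashMap<>();
        dependencies.put("i", VariableDependency.COUNT); // only count(i) is needed downstream
        dependencies.put("y", VariableDependency.FULL);  // the full sequence is needed

        List<String> binaryColumns = new ArrayList<>(); // Kryo-serialized sequences
        List<String> countColumns = new ArrayList<>();  // pre-computed Long counts
        for (String column : columns) {
            if (!dependencies.containsKey(column)) {
                continue; // the expression does not use this variable at all
            }
            if (dependencies.get(column) == VariableDependency.COUNT) {
                countColumns.add(column);
            } else {
                binaryColumns.add(column);
            }
        }
        System.out.println("binary=" + binaryColumns + " counts=" + countColumns);
        // prints: binary=[y] counts=[i]
    }
}

The two new DataFrameUtils helpers below perform this partition over the input schema's field names.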
103 changes: 82 additions & 21 deletions src/main/java/sparksoniq/spark/DataFrameUtils.java
@@ -72,19 +72,9 @@

public class DataFrameUtils {

private static ThreadLocal<byte[]> lastBytesCache = new ThreadLocal<byte[]>() {
@Override
protected byte[] initialValue() {
return null;
}
};
private static ThreadLocal<byte[]> lastBytesCache = ThreadLocal.withInitial(() -> null);

private static ThreadLocal<List<Item>> lastObjectItemCache = new ThreadLocal<List<Item>>() {
@Override
protected List<Item> initialValue() {
return null;
}
};
private static ThreadLocal<List<Item>> lastObjectItemCache = ThreadLocal.withInitial(() -> null);

public static void registerKryoClassesKryo(Kryo kryo) {
kryo.register(Item.class);
@@ -145,7 +135,7 @@ public static List<String> getColumnNames(
int duplicateVariableIndex,
Map<String, DynamicContext.VariableDependency> dependencies
) {
List<String> result = new ArrayList<String>();
List<String> result = new ArrayList<>();
String[] columnNames = inputSchema.fieldNames();
for (int columnIndex = 0; columnIndex < columnNames.length; columnIndex++) {
if (columnIndex == duplicateVariableIndex) {
@@ -159,6 +149,64 @@ public static List<String> getColumnNames(
return result;
}

/**
* @param inputSchema the schema from which the column names are read
* @param duplicateVariableIndex index of a duplicate variable column to skip, or -1 to keep all columns
* @param dependencies variable dependencies restricting the result; null means no restriction
* @return list of SQL column names in the schema, excluding columns that carry pre-computed counts
*/
public static List<String> getColumnNamesExceptPrecomputedCounts(
StructType inputSchema,
int duplicateVariableIndex,
Map<String, DynamicContext.VariableDependency> dependencies
) {
List<String> result = new ArrayList<>();
String[] columnNames = inputSchema.fieldNames();
for (int columnIndex = 0; columnIndex < columnNames.length; columnIndex++) {
if (columnIndex == duplicateVariableIndex) {
continue;
}
String var = columnNames[columnIndex];
if (
dependencies == null
|| (dependencies.containsKey(var)
&& !dependencies.get(var).equals(DynamicContext.VariableDependency.COUNT))
) {
result.add(columnNames[columnIndex]);
}
}
return result;
}

/**
* @param inputSchema the schema from which the column names are read
* @param duplicateVariableIndex index of a duplicate variable column to skip, or -1 to keep all columns
* @param dependencies variable dependencies restricting the result; null means no restriction
* @return list of the SQL column names in the schema that carry pre-computed counts
*/
public static List<String> getPrecomputedCountColumnNames(
StructType inputSchema,
int duplicateVariableIndex,
Map<String, DynamicContext.VariableDependency> dependencies
) {
List<String> result = new ArrayList<>();
String[] columnNames = inputSchema.fieldNames();
for (int columnIndex = 0; columnIndex < columnNames.length; columnIndex++) {
if (columnIndex == duplicateVariableIndex) {
continue;
}
String var = columnNames[columnIndex];
if (
dependencies != null
&& dependencies.containsKey(var)
&& dependencies.get(var).equals(DynamicContext.VariableDependency.COUNT)
) {
result.add(columnNames[columnIndex]);
}
}
return result;
}

/**
* @param inputSchema the schema from which the column names are read
* @return list of SQL column names in the schema
@@ -174,12 +222,26 @@ public static void prepareDynamicContext(
List<String> columnNames,
List<List<Item>> deserializedParams
) {
// prepare dynamic context
for (int columnIndex = 0; columnIndex < columnNames.size(); columnIndex++) {
context.addVariableValue(columnNames.get(columnIndex), deserializedParams.get(columnIndex));
}
}

public static void prepareDynamicContext(
DynamicContext context,
List<String> binaryColumnNames,
List<String> countColumnNames,
List<List<Item>> deserializedParams,
List<Item> counts
) {
for (int columnIndex = 0; columnIndex < binaryColumnNames.size(); columnIndex++) {
context.addVariableValue(binaryColumnNames.get(columnIndex), deserializedParams.get(columnIndex));
}
for (int columnIndex = 0; columnIndex < countColumnNames.size(); columnIndex++) {
context.addVariableCount(countColumnNames.get(columnIndex), counts.get(columnIndex));
}
}

/**
* @param columnNames the column names to render into the SQL projection
* @param trailingComma whether to append a trailing comma
@@ -274,12 +336,11 @@ public static String getGroupbyProjectSQL(
return queryColumnString.toString();
}

public static Object deserializeByteArray(byte[] toDeserialize, Kryo kryo, Input input) {
private static Object deserializeByteArray(byte[] toDeserialize, Kryo kryo, Input input) {
byte[] bytes = lastBytesCache.get();
if (bytes != null) {
if (Arrays.equals(bytes, toDeserialize)) {
List<Item> deserializedParam = lastObjectItemCache.get();
return deserializedParam;
return lastObjectItemCache.get();
}
}
input.setBuffer(toDeserialize);
@@ -325,7 +386,7 @@ public static Row reserializeRowWithNewData(
public static List<Item> deserializeRowField(Row row, int columnIndex, Kryo kryo, Input input) {
Object o = row.get(columnIndex);
if (o instanceof Long) {
List<Item> result = new ArrayList<Item>(1);
List<Item> result = new ArrayList<>(1);
result.add(ItemFactory.getInstance().createIntegerItem(((Long) o).intValue()));
return result;
} else {
@@ -376,12 +437,12 @@ public static Dataset<Row> zipWithIndex(Dataset<Row> df, Long offset, String ind
.collect();
Row[] partitionOffsetsArray = ((Row[]) partitionOffsetsObject);
Map<Integer, Long> partitionOffsets = new HashMap<>();
for (int i = 0; i < partitionOffsetsArray.length; i++) {
partitionOffsets.put(partitionOffsetsArray[i].getInt(0), partitionOffsetsArray[i].getLong(1));
for (Row row : partitionOffsetsArray) {
partitionOffsets.put(row.getInt(0), row.getLong(1));
}

UserDefinedFunction getPartitionOffset = udf(
(partitionId) -> partitionOffsets.get((Integer) partitionId),
(partitionId) -> partitionOffsets.get(partitionId),
DataTypes.LongType
);

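Alongside the main fix, the two thread-local caches at the top of DataFrameUtils.java were rewritten with ThreadLocal.withInitial. A small sketch, runnable on Java 8+, showing that the two forms behave identically:

// Minimal sketch (not from the commit): withInitial replaces the anonymous
// subclass that overrides initialValue(), with the same semantics.
public class ThreadLocalSketch {
    // Old style: anonymous subclass overriding initialValue()
    private static final ThreadLocal<byte[]> oldStyle = new ThreadLocal<byte[]>() {
        @Override
        protected byte[] initialValue() {
            return null;
        }
    };

    // New style: factory method, same initial value, less boilerplate
    private static final ThreadLocal<byte[]> newStyle = ThreadLocal.withInitial(() -> null);

    public static void main(String[] args) {
        // Both yield the same initial value on first access in a thread.
        System.out.println(oldStyle.get() == null && newStyle.get() == null); // true
    }
}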
@@ -244,26 +244,37 @@ public Dataset<Row> getDataFrame(
int duplicateVariableIndex = columnNames.indexOf(newVariableName);

List<String> allColumns = DataFrameUtils.getColumnNames(inputSchema, duplicateVariableIndex, null);
List<String> UDFcolumns = DataFrameUtils.getColumnNames(inputSchema, -1, _dependencies);
List<String> UDFbinarycolumns = DataFrameUtils.getColumnNamesExceptPrecomputedCounts(
inputSchema,
-1,
_dependencies
);
List<String> UDFlongcolumns = DataFrameUtils.getPrecomputedCountColumnNames(
inputSchema,
-1,
_dependencies
);

df.sparkSession()
.udf()
.register(
"letClauseUDF",
new LetClauseUDF(newVariableExpression, UDFcolumns),
new LetClauseUDF(newVariableExpression, UDFbinarycolumns, UDFlongcolumns),
DataTypes.BinaryType
);

String selectSQL = DataFrameUtils.getSQL(allColumns, true);
String udfSQL = DataFrameUtils.getSQL(UDFcolumns, false);
String udfBinarySQL = DataFrameUtils.getSQL(UDFbinarycolumns, false);
String udfLongSQL = DataFrameUtils.getSQL(UDFlongcolumns, false);

df.createOrReplaceTempView("input");
df = df.sparkSession()
.sql(
String.format(
"select %s letClauseUDF(array(%s)) as `%s` from input",
"select %s letClauseUDF(array(%s), array(%s)) as `%s` from input",
selectSQL,
udfSQL,
udfBinarySQL,
udfLongSQL,
newVariableName
)
);
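The reworked SQL above now passes two arrays to letClauseUDF, one of binary columns and one of Long count columns, which is why the UDF must implement UDF2 rather than UDF1. A hedged, self-contained sketch of that two-array shape, with toy logic and invented names rather than the real LetClauseUDF:

import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF2;
import org.apache.spark.sql.types.DataTypes;
import scala.collection.mutable.WrappedArray;

// Hedged sketch: a UDF2 taking one array of "serialized" columns and one
// array of Long columns, mirroring letClauseUDF(array(%s), array(%s)).
public class Udf2Sketch {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
            .master("local[*]")
            .appName("udf2-sketch")
            .getOrCreate();

        spark.udf().register(
            "toySumUDF",
            (UDF2<WrappedArray<String>, WrappedArray<Long>, Long>) (strings, longs) -> {
                long total = 0;
                for (int i = 0; i < strings.length(); i++) {
                    total += strings.apply(i).length(); // stand-in for Kryo deserialization
                }
                for (int i = 0; i < longs.length(); i++) {
                    total += longs.apply(i); // pre-computed counts arrive as plain Longs
                }
                return total;
            },
            DataTypes.LongType
        );

        spark.sql("select toySumUDF(array('ab', 'c'), array(10L)) as result").show(); // 13
        spark.stop();
    }
}

Here array('ab', 'c') stands in for the serialized binary columns and array(10L) for the pre-computed counts; the real UDF deserializes the former with Kryo and wraps the latter as integer items.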
@@ -163,26 +163,33 @@ public Dataset<Row> getDataFrame(
int duplicateVariableIndex = Arrays.asList(inputSchema.fieldNames()).indexOf(_variableName);

List<String> allColumns = DataFrameUtils.getColumnNames(inputSchema, duplicateVariableIndex, null);
List<String> UDFcolumns = DataFrameUtils.getColumnNames(inputSchema, -1, _dependencies);
List<String> UDFbinarycolumns = DataFrameUtils.getColumnNamesExceptPrecomputedCounts(
inputSchema,
-1,
_dependencies
);
List<String> UDFlongcolumns = DataFrameUtils.getPrecomputedCountColumnNames(inputSchema, -1, _dependencies);

df.sparkSession()
.udf()
.register(
"letClauseUDF",
new LetClauseUDF(_expression, UDFcolumns),
new LetClauseUDF(_expression, UDFbinarycolumns, UDFlongcolumns),
DataTypes.BinaryType
);

String selectSQL = DataFrameUtils.getSQL(allColumns, true);
String udfSQL = DataFrameUtils.getSQL(UDFcolumns, false);
String udfBinarySQL = DataFrameUtils.getSQL(UDFbinarycolumns, false);
String udfLongSQL = DataFrameUtils.getSQL(UDFlongcolumns, false);

df.createOrReplaceTempView("input");
df = df.sparkSession()
.sql(
String.format(
"select %s letClauseUDF(array(%s)) as `%s` from input",
"select %s letClauseUDF(array(%s), array(%s)) as `%s` from input",
selectSQL,
udfSQL,
udfBinarySQL,
udfLongSQL,
_variableName
)
);
34 changes: 28 additions & 6 deletions src/main/java/sparksoniq/spark/udf/LetClauseUDF.java
@@ -21,13 +21,15 @@
package sparksoniq.spark.udf;

import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.api.java.UDF2;
import org.rumbledb.api.Item;

import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;

import scala.collection.mutable.WrappedArray;
import sparksoniq.jsoniq.item.ItemFactory;
import sparksoniq.jsoniq.runtime.iterator.RuntimeIterator;
import sparksoniq.semantics.DynamicContext;
import sparksoniq.spark.DataFrameUtils;
@@ -36,14 +38,16 @@
import java.util.ArrayList;
import java.util.List;

public class LetClauseUDF implements UDF1<WrappedArray<byte[]>, byte[]> {
public class LetClauseUDF implements UDF2<WrappedArray<byte[]>, WrappedArray<Long>, byte[]> {

private static final long serialVersionUID = 1L;
private RuntimeIterator _expression;

List<String> _columnNames;
List<String> _binaryColumnNames;
List<String> _longColumnNames;

private List<List<Item>> _deserializedParams;
private List<Item> _longParams;
private DynamicContext _context;
private List<Item> _nextResult;

@@ -53,11 +57,13 @@ public class LetClauseUDF implements UDF1<WrappedArray<byte[]>, byte[]> {

public LetClauseUDF(
RuntimeIterator expression,
List<String> columnNames
List<String> binaryColumnNames,
List<String> longColumnNames
) {
_expression = expression;

_deserializedParams = new ArrayList<>();
_longParams = new ArrayList<>();
_context = new DynamicContext();
_nextResult = new ArrayList<>();

@@ -67,19 +73,35 @@ public LetClauseUDF(
_output = new Output(128, -1);
_input = new Input();

_columnNames = columnNames;
_binaryColumnNames = binaryColumnNames;
_longColumnNames = longColumnNames;
}


@Override
public byte[] call(WrappedArray<byte[]> wrappedParameters) {
public byte[] call(WrappedArray<byte[]> wrappedParameters, WrappedArray<Long> wrappedParametersLong) {
_deserializedParams.clear();
_longParams.clear();
_context.removeAllVariables();
_nextResult.clear();

DataFrameUtils.deserializeWrappedParameters(wrappedParameters, _deserializedParams, _kryo, _input);

DataFrameUtils.prepareDynamicContext(_context, _columnNames, _deserializedParams);
// Long parameters correspond to pre-computed counts, when a materialization of the
// actual sequence was avoided upfront.
Object[] longParams = (Object[]) wrappedParametersLong.array();
for (Object longParam : longParams) {
Item count = ItemFactory.getInstance().createIntegerItem(((Long) longParam).intValue());
_longParams.add(count);
}

DataFrameUtils.prepareDynamicContext(
_context,
_binaryColumnNames,
_longColumnNames,
_deserializedParams,
_longParams
);

// apply expression in the dynamic context
_expression.open(_context);
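One detail visible in call() above: the UDF reuses its mutable fields (_deserializedParams, _longParams, _context, _nextResult) across rows, clearing them at the start of each invocation instead of allocating fresh objects per row. A minimal sketch of that per-row reuse pattern, with invented names:

import java.util.ArrayList;
import java.util.List;

// Hedged sketch: per-row state reuse, as in LetClauseUDF.call() above.
public class ReusePatternSketch {
    private final List<String> scratch = new ArrayList<>();

    public int call(String csvRow) {
        scratch.clear(); // reset state left over from the previous row
        for (String field : csvRow.split(",")) {
            scratch.add(field.trim());
        }
        return scratch.size();
    }

    public static void main(String[] args) {
        ReusePatternSketch udf = new ReusePatternSketch();
        System.out.println(udf.call("a, b, c")); // 3
        System.out.println(udf.call("x"));       // 1, no stale state
    }
}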
@@ -0,0 +1,6 @@
(:JIQS: ShouldRun; Output="(2, 1, 1, 1)" :)
for $i in json-file("./src/main/resources/queries/conf-ex.json")
group by $y := $i.country, $t := $i.target
let $c := count($i)
order by $c descending
return $c
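The expected output (2, 1, 1, 1) lists the group sizes in descending order: one (country, target) group containing two records, and three singleton groups. A hedged plain-Java analogue of the query, with invented sample data since the contents of conf-ex.json are not shown in this diff:

import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

// Hedged analogue of the JSONiq test above: group by (country, target),
// count each group, sort the counts descending. Sample data is invented.
public class GroupCountSketch {
    public static void main(String[] args) {
        List<String[]> records = Arrays.asList(
            new String[] { "CH", "a" },
            new String[] { "CH", "a" },
            new String[] { "DE", "a" },
            new String[] { "FR", "b" },
            new String[] { "FR", "c" }
        );
        List<Long> counts = records.stream()
            .collect(Collectors.groupingBy(r -> r[0] + "|" + r[1], Collectors.counting()))
            .values().stream()
            .sorted(Comparator.reverseOrder())
            .collect(Collectors.toList());
        System.out.println(counts); // [2, 1, 1, 1]
    }
}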
