Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP / experiment] Specialization of std.foldl(std.mergePatch, arr, target) #254

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sjsonnet/src/sjsonnet/Settings.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ class Settings(
val strictInheritedAssertions: Boolean = false,
val strictSetOperations: Boolean = false,
val throwErrorForInvalidSets: Boolean = false,
val disableBuiltinSpecialization: Boolean = false,
val disableStaticApplyForBuiltinFunctions: Boolean = false
)

object Settings {
Expand Down
13 changes: 9 additions & 4 deletions sjsonnet/src/sjsonnet/StaticOptimizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ class StaticOptimizer(
}

private def tryStaticApply(pos: Position, f: Val.Builtin, args: Array[Expr]): Expr = {
if(f.staticSafe && args.forall(_.isInstanceOf[Val])) {
if (ev.settings.disableStaticApplyForBuiltinFunctions) null
else if(f.staticSafe && args.forall(_.isInstanceOf[Val])) {
val vargs = args.map(_.asInstanceOf[Val])
try f.apply(vargs, null, pos)(ev).asInstanceOf[Expr] catch { case _: Exception => return null }
} else null
Expand All @@ -146,9 +147,13 @@ class StaticOptimizer(
case newArgs =>
tryStaticApply(pos, f, newArgs) match {
case null =>
val (f2, rargs) = f.specialize(newArgs) match {
case null => (f, newArgs)
case (f2, a2) => (f2, a2)
val (f2, rargs) = if (ev.settings.disableBuiltinSpecialization) {
(f, newArgs)
} else {
f.specialize(newArgs) match {
case null => (f, newArgs)
case (f2, a2) => (f2, a2)
}
}
val alen = rargs.length
f2 match {
Expand Down
173 changes: 173 additions & 0 deletions sjsonnet/src/sjsonnet/Std.scala
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,181 @@ class Std(private val additionalNativeFunctions: Map[String, Val.Builtin] = Map.

case _ => Error.fail("Cannot call foldl on " + arr.prettyName)
}
}

/**
 * Attempts to replace a call to this builtin with a cheaper specialized builtin.
 *
 * Currently recognizes only `std.foldl(std.mergePatch, arr, target)` where the function
 * argument is a builtin named `mergePatch` and `target` is already an object value; that
 * call is rewritten to `FoldlMergePatch(arr, target)`.
 *
 * @param args the (already optimized) argument expressions of the call
 * @return the replacement builtin together with its arguments, or `null` when no
 *         specialization applies (callers treat `null` as "keep the original call")
 */
override def specialize(args: Array[Expr]): (Val.Builtin, Array[Expr]) =
  args match {
    // Specialize std.foldl(std.mergePatch, arr, target)
    case Array(f: Val.Builtin, arr, target: Val.Obj) if f.functionName == "mergePatch" =>
      (FoldlMergePatch, Array(arr, target))
    // Explicit fallback: previously a MatchError was thrown here and swallowed by a
    // catch-all `Exception` handler, which would also have hidden any genuine failure.
    case _ => null
  }
}

/**
 * Optimized equivalent of `std.foldl(std.mergePatch, arr, target)`.
 *
 * Rather than building one intermediate object per patch, the patches are merged key by key:
 * for every key the last contributing value decides the outcome, and only the trailing run of
 * object-valued patches has to be merged recursively.
 *
 * Merge-patch semantics relied on below: a `null` patch value removes the key, a non-object
 * patch value replaces whatever was there, and an object patch applied on top of a non-object
 * (or missing) value is merged into an empty object.
 */
private object FoldlMergePatch extends Val.Builtin2("mergePatchAll", "array", "target") {

  /**
   * Recursively process an object by:
   *  - removing any fields explicitly set to null,
   *  - removing any non-visible fields,
   *  - stripping the `add` modifier off of any fields with it.
   *
   * @return `obj` itself when none of the above applied (avoids an allocation), otherwise a
   *         freshly built object holding the already-evaluated visible fields
   */
  private def cleanObject(obj: Val.Obj, ev: EvalScope): Val.Obj = {
    val visibleKeys = obj.visibleKeyNames
    // Evaluated field values, index-aligned with `visibleKeys`; `null` marks a dropped field.
    val newFields: Array[Val] = new Array(visibleKeys.length)
    var i = 0
    // A rebuild is needed up front if there are hidden fields or `+:` fields to strip.
    var updateNeeded: Boolean = (visibleKeys.length != obj.allKeysSize) || obj.hasAPlusField
    while (i < visibleKeys.length) {
      val key = visibleKeys(i)
      val value = obj.value(key, obj.pos.noOffset, obj)(ev)
      if (value.isInstanceOf[Val.Null]) {
        newFields(i) = null
        updateNeeded = true
      } else {
        val newValue = value match {
          case nested: Val.Obj => cleanObject(nested, ev)
          case _ => value
        }
        newFields(i) = newValue
        // Reference inequality means the nested object had to be rebuilt.
        if (newValue ne value) updateNeeded = true
      }
      i += 1
    }
    if (updateNeeded) {
      val newFieldsMap = new util.LinkedHashMap[String, Val.Obj.Member](visibleKeys.length)
      i = 0
      while (i < visibleKeys.length) {
        val key = visibleKeys(i)
        val value = newFields(i)
        if (value != null) {
          newFieldsMap.put(key, createMember(value))
        }
        i += 1
      }
      new Val.Obj(obj.pos, newFieldsMap, false, null, null)
    } else {
      obj
    }
  }

  /** Wraps an already-computed value in a member without the `add` modifier. */
  private def createMember(v: Val) = new Val.Obj.Member(false, Visibility.Unhide) {
    def invoke(self: Val.Obj, sup: Val.Obj, fs: FileScope, ev: EvalScope): Val = v
  }

  // Placeholder to represent absence of a value. We use this instead of `null` because
  // LinkedHashMap.putIfAbsent still affects insertion order if the existing value is null.
  private[this] val nullCanary = new Val.Obj.ConstMember(false, Visibility.Normal, Val.Null(pos))

  /**
   * Applies every element of `arr` as a merge patch onto `target`, left to right.
   *
   * @param arr    must evaluate to an array of patches; anything else fails with
   *               "Expected array, got ..."
   * @param target the initial value the patches are folded onto
   * @return `target` unchanged for an empty array, the last element if it is not an object,
   *         otherwise the merged object
   */
  def evalRhs(arr: Val, target: Val, ev: EvalScope, pos: Position): Val = {
    // Here, `objectsSize` is the number of valid entries in the `objects` array.
    // This is an optimization to avoid having to trim or resize intermediate arrays.
    // `target` may be a non-object (or null), in which case it contributes no fields.
    def recMerge(target: Val, objects: Array[Val.Obj], objectsSize: Int): Val.Obj = {
      // Determine an upper bound of the final key set (only a bound because a key
      // might end up being removed and we can only know that after further processing).
      // We need an `outputFields` LinkedHashMap anyways, so we'll first use it to
      // collect the distinct fields and then will update it to either populate the members
      // or remove fields that were later determined to be unused.
      val outputFields = new util.LinkedHashMap[String, Val.Obj.Member]()
      target match {
        case t: Val.Obj =>
          t.visibleKeyNames.foreach(k => outputFields.putIfAbsent(k, nullCanary))
        case _ =>
      }
      var idx = 0
      while (idx < objectsSize) {
        objects(idx).visibleKeyNames.foreach(k => outputFields.putIfAbsent(k, nullCanary))
        idx += 1
      }

      // Perform the merge for each key:
      val keysIter = outputFields.keySet().iterator()
      // Scratch buffer reused for every key; nested calls allocate their own.
      val objValues = new Array[Val.Obj](objectsSize)
      while (keysIter.hasNext) {
        val key = keysIter.next()
        val targetValue: Val = target match {
          case targetObj: Val.Obj if targetObj.containsVisibleKey(key) =>
            targetObj.valueRaw(key, targetObj, pos)(ev)
          case _ =>
            null
        }

        var lastValue: Val = null
        var objCount = 0
        // Set once some patch replaced this key with a null/non-object value: from then on
        // the original target value is gone and must not take part in later merges.
        var targetOverwritten = false
        var i = 0

        // Loop over the patches, determining either the final non-object
        // value which overwrites the target key or determining the subset
        // of patches which contribute to the final value.
        while (i < objectsSize) {
          val obj = objects(i)
          if (obj.containsVisibleKey(key)) {
            lastValue = obj.valueRaw(key, obj, pos)(ev)
            lastValue match {
              case o: Val.Obj =>
                objValues(objCount) = o
                objCount += 1
              case _ =>
                // Got either a Null or a non-Obj, but in either case we
                // won't use any of the earlier values for this key in the
                // merged result since they would be overwritten and discarded
                // at this step.
                objCount = 0
                targetOverwritten = true
            }
          }
          i += 1
        }
        val removeField = lastValue.isInstanceOf[Val.Null]
        if (removeField) {
          keysIter.remove()
        } else {
          // What the trailing object patches are merged onto: the target's value, unless an
          // intermediate patch already replaced it.
          val mergeBase: Val = if (targetOverwritten) null else targetValue
          val finalValue: Val = {
            if (objCount == 0) {
              // No trailing object patches: either a scalar patch won, or no patch
              // mentioned this key and the target's value is kept as-is.
              if (lastValue != null) lastValue else targetValue
            } else if (objCount == 1 && !mergeBase.isInstanceOf[Val.Obj]) {
              // A single object patch onto a missing/non-object value is just the
              // cleaned patch.
              cleanObject(objValues(0), ev)
            } else {
              recMerge(mergeBase, objValues, objCount)
            }
          }
          outputFields.replace(key, createMember(finalValue))
        }
      }
      new Val.Obj(pos, outputFields, false, null, null)
    }

    arr match {
      case patches: Val.Arr =>
        val length = patches.length
        if (length == 0) {
          target
        } else {
          val objects = new Array[Val.Obj](length)
          // Value the trailing run of object patches is merged onto. It starts as `target`
          // and is replaced whenever a non-object patch overwrites the accumulated result.
          var mergeBase: Val = target
          var arrIdx = 0
          var objectsIdx = 0
          while (arrIdx < length) {
            patches.force(arrIdx) match {
              case obj: Val.Obj =>
                objects(objectsIdx) = obj
                objectsIdx += 1
              case other =>
                objectsIdx = 0
                mergeBase = other
            }
            arrIdx += 1
          }
          if (objectsIdx == 0) {
            // The last element is a non-object, so it overwrites everything.
            patches.force(length - 1)
          } else {
            recMerge(mergeBase, objects, objectsIdx)
          }
        }

      case v => Error.fail(s"Expected array, got ${v.prettyName}", pos)(ev)
    }
  }
}

Expand Down
4 changes: 4 additions & 0 deletions sjsonnet/src/sjsonnet/Val.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import sjsonnet.Expr.Params
import scala.annotation.tailrec
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConverters._
import scala.reflect.ClassTag

/**
Expand Down Expand Up @@ -227,6 +228,9 @@ object Val{

@inline def containsVisibleKey(k: String): Boolean = getAllKeys.get(k) == java.lang.Boolean.FALSE

// Number of entries in the `getAllKeys` map; no visibility filtering is applied here.
def allKeysSize: Int = getAllKeys.size()
// True when at least one member in `getValue0` has its `add` flag set.
// NOTE(review): presumably `getValue0` holds only this object's own members, so `add`
// fields contributed through a super object may not be detected — confirm.
def hasAPlusField: Boolean = getValue0.values().asScala.exists(_.add)

lazy val allKeyNames: Array[String] = getAllKeys.keySet().toArray(new Array[String](getAllKeys.size()))

lazy val visibleKeyNames: Array[String] = if(static) allKeyNames else {
Expand Down
151 changes: 151 additions & 0 deletions sjsonnet/test/src/sjsonnet/FoldlMergePatchSpecializationTests.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
package sjsonnet

import utest._
import TestUtils.{assertSameEvalWithAndWithoutSpecialization, evalErr}

// Tests for the specialized implementation of `std.foldl(std.mergePatch, arr, target)`.
// Every case except "error handling" goes through `check`, so each Jsonnet snippet below
// doubles as a comparison between the specialized and the generic code path.
//
// NOTE(review): every target used below is either `{}` or a flat object; no case merges
// object patches onto a target that already holds a nested object, and the cases with
// non-object array elements all use `{}` as the target. Consider adding such cases.
object FoldlMergePatchSpecializationTests extends TestSuite {

// Delegates to `TestUtils.assertSameEvalWithAndWithoutSpecialization`; judging by its name
// it evaluates `s` with and without builtin specialization and asserts the results agree.
@noinline
private def check(s: String): Unit = {
assertSameEvalWithAndWithoutSpecialization(s)
}

def tests = Tests {
test("empty array handling") {
check("""std.foldl(std.mergePatch, [], {})""")
check("""std.foldl(std.mergePatch, [{}], {})""")
}

test("single patch to empty object") {
check("""std.foldl(std.mergePatch, [1], {})""")
check("""std.foldl(std.mergePatch, [null], {})""")
check("""std.foldl(std.mergePatch, [{}], {})""")
check("""std.foldl(std.mergePatch, [{a: 1}], {})""")
check("""std.foldl(std.mergePatch, [{a: null}], {})""")
check("""std.foldl(std.mergePatch, [{a: {b: null}}], {})""")
check("""std.foldl(std.mergePatch, [{a: {b: {c: null}}}], {})""")
check("""std.objectFieldsAll(std.foldl(std.mergePatch, [{a: 1, b:: 1}], {}))""")
check("""std.objectFieldsAll(std.foldl(std.mergePatch, [{a: {b: { c:: 1, d: 2}}}], {}).a.b)""")
}

test("basic non-nested merging") {
check("""std.foldl(std.mergePatch, [{a: 1}, {b: 2}], {})""")
check("""std.foldl(std.mergePatch, [{b: 2}], {a: 1})""")

check("""std.foldl(std.mergePatch, [{a: 1}, {b: 2}, {c: 3}], {})""")
check("""std.foldl(std.mergePatch, [{b: 2}, {c: 3}], {a: 1})""")

check("""std.foldl(std.mergePatch, [{a: 1}, {a: 2}, {a: 3}], {})""")
check("""std.foldl(std.mergePatch, [{a: 2}, {a: 3}], {a: 1})""")

check("""std.foldl(std.mergePatch, [{a: 1, b: 1}, {b: 2, c: 2}, {c: 3, d: 3}], {})""")
check("""std.foldl(std.mergePatch, [{b: 2, c: 2}, {c: 3, d: 3}], {a: 1, b: 1})""")
}

test("merging of non-object patches") {
check("""std.foldl(std.mergePatch, [{a: 1}, 1, 2, {a: 3}], {})""")
check("""std.foldl(std.mergePatch, [1, 2, 3, 4], {})""")
check("""std.foldl(std.mergePatch, [1, {a: 1}], {})""")
}

test("nested object merging") {
check("""std.foldl(std.mergePatch, [{a: {x: 1}}, {a: {y: 2}}], {})""")
check("""std.foldl(std.mergePatch, [{a: {x: 1, y: 1}}, {a: {y: 2}}], {})""")
check("""std.foldl(std.mergePatch, [{a: {b: {x: 1}}}, {a: {b: {y: 2}}}], {})""")
check("""std.foldl(std.mergePatch, [{a: {x: 1}, b: 1}, {a: {y: 2}, c: 2}], {})""")
}

test("null handling") {
check("""std.foldl(std.mergePatch, [{a: {x: 1, y: 1}}, {a: {y: null}}], {})""")
check("""std.foldl(std.mergePatch, [{a: 1, b: 1}, {a: null}], {})""")
check("""std.foldl(std.mergePatch, [{a: 1}, {a: null}, {a: 2}], {})""")

check("""
local arr = [
{a: {x: 1, y: 1}, b: 1},
{a: {y: null}, b: null},
{a: {z: 3}}
];
std.foldl(std.mergePatch, arr, {})
""")
}

test("hidden field handling") {
// Hidden fields should always be dropped
check("""std.foldl(std.mergePatch, [{a:: 1}, {b: 2}], {})""")
check("""std.objectFieldsAll(std.foldl(std.mergePatch, [{a:: 1}, {b: 2}], {}))""")
check("""std.foldl(std.mergePatch, [{b: 2}], {a:: 1})""")
check("""std.objectFieldsAll(std.foldl(std.mergePatch, [{b: 2}], {a:: 1}))""")

check("""std.foldl(std.mergePatch, [], {a:: 1})""")
check("""std.objectFieldsAll(std.foldl(std.mergePatch, [], {a:: 1}))""")

check("""std.foldl(std.mergePatch, [{b: 2}], {a:: 1})""")
check("""std.objectFieldsAll(std.foldl(std.mergePatch, [{b: 2}], {a:: 1}))""")

check("""std.foldl(std.mergePatch, [{a: 1}, {b:: 2}], {})""")
check("""std.objectFieldsAll(std.foldl(std.mergePatch, [{a: 1}, {b:: 2}], {}))""")

// Nested hidden fields should also be dropped
check("""
local arr = [{a: {h:: 1, v: 1}}, {b: 2}];
std.foldl(std.mergePatch, arr, {})
""")
check(
"""
local arr = [{a: {h:: 1, v: 1}}, {b: 2}];
std.objectFieldsAll(std.foldl(std.mergePatch, arr, {}).a)
""")

// Hidden fields do not merge with non-hidden fields
check("""std.foldl(std.mergePatch, [{a: {b: 1}}, {a:: {c: 1}}], {})""")
}

test("ordering preservation") {
check("""std.foldl(std.mergePatch, [{b: 1, a: 1}, {c: 1, d: 1}], {})""")
check("""std.foldl(std.mergePatch, [{a: 2, b: 2}], {b: 1, a: 1})""")
}

test("plus operator handling") {
// The +: operator should be ignored during merging
check("""std.foldl(std.mergePatch, [{a: 1}, {a+: 2}], {})""")
check("""std.foldl(std.mergePatch, [{a+: 2}], {a: 1})""")

// The resulting field should not be treated as +: in future merges
check("""
local result = std.foldl(std.mergePatch, [{}, {a+: 2}], {});
{a: 1} + result
""")
check(
"""
local result = std.foldl(std.mergePatch, [{}, {a+: 2}], {});
{a: 1} + result
""")

// +: in first object should also be ignored
check("""
local result = std.foldl(std.mergePatch, [{}], {a+: 2});
{a: 1} + result
""")

// Should work the same for nested fields
check("""
local result = std.foldl(std.mergePatch, [{a: {b+: 2}}, {}], {});
{a: {b: 1}} + result
""")
}

test("error handling") {
// The same invalid call is evaluated twice with static application disabled both times.
// The expected message differs depending on whether builtin specialization is enabled,
// so this test pins the error text of both code paths rather than asserting they agree.
assert(evalErr(
"""std.foldl(std.mergePatch, null, {})""",
disableStaticApplyForBuiltinFunctions = true,
disableBuiltinSpecialization = true
).startsWith("sjsonnet.Error: Cannot call foldl on null"))
assert(evalErr(
"""std.foldl(std.mergePatch, null, {})""",
disableStaticApplyForBuiltinFunctions = true,
disableBuiltinSpecialization = false
).startsWith("sjsonnet.Error: Expected array, got null"))
}
}
}
Loading
Loading