diff --git a/sjsonnet/src/sjsonnet/Settings.scala b/sjsonnet/src/sjsonnet/Settings.scala index 1f13b713..c0498427 100644 --- a/sjsonnet/src/sjsonnet/Settings.scala +++ b/sjsonnet/src/sjsonnet/Settings.scala @@ -11,6 +11,8 @@ class Settings( val strictInheritedAssertions: Boolean = false, val strictSetOperations: Boolean = false, val throwErrorForInvalidSets: Boolean = false, + val disableBuiltinSpecialization: Boolean = false, + val disableStaticApplyForBuiltinFunctions: Boolean = false ) object Settings { diff --git a/sjsonnet/src/sjsonnet/StaticOptimizer.scala b/sjsonnet/src/sjsonnet/StaticOptimizer.scala index 6e1e2854..377589e6 100644 --- a/sjsonnet/src/sjsonnet/StaticOptimizer.scala +++ b/sjsonnet/src/sjsonnet/StaticOptimizer.scala @@ -122,7 +122,8 @@ class StaticOptimizer( } private def tryStaticApply(pos: Position, f: Val.Builtin, args: Array[Expr]): Expr = { - if(f.staticSafe && args.forall(_.isInstanceOf[Val])) { + if (ev.settings.disableStaticApplyForBuiltinFunctions) null + else if(f.staticSafe && args.forall(_.isInstanceOf[Val])) { val vargs = args.map(_.asInstanceOf[Val]) try f.apply(vargs, null, pos)(ev).asInstanceOf[Expr] catch { case _: Exception => return null } } else null @@ -146,9 +147,13 @@ class StaticOptimizer( case newArgs => tryStaticApply(pos, f, newArgs) match { case null => - val (f2, rargs) = f.specialize(newArgs) match { - case null => (f, newArgs) - case (f2, a2) => (f2, a2) + val (f2, rargs) = if (ev.settings.disableBuiltinSpecialization) { + (f, newArgs) + } else { + f.specialize(newArgs) match { + case null => (f, newArgs) + case (f2, a2) => (f2, a2) + } } val alen = rargs.length f2 match { diff --git a/sjsonnet/src/sjsonnet/Std.scala b/sjsonnet/src/sjsonnet/Std.scala index 4d594f7d..941ba1cb 100644 --- a/sjsonnet/src/sjsonnet/Std.scala +++ b/sjsonnet/src/sjsonnet/Std.scala @@ -254,8 +254,181 @@ class Std(private val additionalNativeFunctions: Map[String, Val.Builtin] = Map. 
case _ => Error.fail("Cannot call foldl on " + arr.prettyName) } + } + + override def specialize(args: Array[Expr]): (Val.Builtin, Array[Expr]) = { + try { + args match { + // Specialize std.foldl(std.mergePatch, arr, target) + case Array(f: Val.Builtin, arr, target: Val.Obj) if f.functionName == "mergePatch" => + (FoldlMergePatch, Array(arr, target)) + } + } catch { + case _: Exception => null + } + } + } + + /** + * Optimized equivalent of `std.foldl(std.mergePatch, arr, target)` + */ + private object FoldlMergePatch extends Val.Builtin2("mergePatchAll", "array", "target") { + + /** + * Recursively process an object by: + * Removing any fields explicitly set to null. + * Removing any non-visible fields. + * Stripping the `add` modifier off of any fields with it. + */ + private def cleanObject(obj: Val.Obj, ev: EvalScope): Val.Obj = { + val visibleKeys = obj.visibleKeyNames + val newFields: Array[Val] = new Array(visibleKeys.length) + var i = 0 + var updateNeeded: Boolean = (visibleKeys.length != obj.allKeysSize) || obj.hasAPlusField + while (i < visibleKeys.length) { + val key = visibleKeys(i) + val value = obj.value(key, obj.pos.noOffset, obj)(ev) + if (value.isInstanceOf[Val.Null]) { + newFields(i) = null + updateNeeded = true + } else { + val newValue = value match { + case obj: Val.Obj => cleanObject(obj, ev) + case _ => value + } + newFields(i) = newValue + if (newValue ne value) updateNeeded = true + } + i += 1 + } + if (updateNeeded) { + val newFieldsMap = new util.LinkedHashMap[String, Val.Obj.Member](visibleKeys.length) + i = 0 + while (i < visibleKeys.length) { + val key = visibleKeys(i) + val value = newFields(i) + if (value != null) { + newFieldsMap.put(key, createMember(value)) + } + i += 1 + } + new Val.Obj(obj.pos, newFieldsMap, false, null, null) + } else { + obj + } + } + private def createMember(v: Val) = new Val.Obj.Member(false, Visibility.Unhide) { + def invoke(self: Val.Obj, sup: Val.Obj, fs: FileScope, ev: EvalScope): Val = v + } + + // 
Placeholder to represent absence of a value. We use this instead of `null` because + // LinkedHashMap.putIfAbsent still affects insertion order if the existing value is null. + private[this] val nullCanary = new Val.Obj.ConstMember(false, Visibility.Normal, Val.Null(pos)) + + def evalRhs(arr: Val, target: Val, ev: EvalScope, pos: Position): Val = { + // Here, `objectsSize` is the number of valid entries in the `objects` array. + // This is an optimization to avoid having to trim or resize intermediate arrays. + def recMerge(target: Val, objects: Array[Val.Obj], objectsSize: Int): Val.Obj = { + // Determine an upper bound of the final key set (only a bound because a key + // might end up being removed and we can only know that after further processing). + // We need an `outputFields` LinkedHashMap anyway, so we'll first use it to + // collect the distinct fields and then will update it to either populate the members + // or remove fields that were later determined to be unused. + val outputFields = new util.LinkedHashMap[String, Val.Obj.Member]() + target match { + case t: Val.Obj => + t.visibleKeyNames.foreach(k => outputFields.putIfAbsent(k, nullCanary)) + case _ => + } + var idx = 0 + while (idx < objectsSize) { + objects(idx).visibleKeyNames.foreach(k => outputFields.putIfAbsent(k, nullCanary)) + idx += 1 + } + + // Perform the merge for each key: + val keysIter = outputFields.keySet().iterator() + val objValues = new Array[Val.Obj](objectsSize) + while (keysIter.hasNext) { + val key = keysIter.next() + val targetValue: Val = target match { + case targetObj: Val.Obj if targetObj.containsVisibleKey(key) => + targetObj.valueRaw(key, targetObj, pos)(ev) + case _ => + null + } + + var lastValue: Val = null + var objCount = 0 + var i = 0 + // Loop over the patches, determining either the final non-object + // value which overwrites the target key or determining the subset + // of patches which contribute to the final value. 
+ while (i < objectsSize) { + val obj = objects(i) + if (obj.containsVisibleKey(key)) { + lastValue = obj.valueRaw(key, obj, pos)(ev) + lastValue match { + case _: Val.Obj => + objValues(objCount) = lastValue.asInstanceOf[Val.Obj] + objCount += 1 + case _ => + // Got either a Null or a non-Obj, but in either case we + // won't use any of the earlier values for this key in the + // merged result since they would be overwritten and discarded + // at this step. + objCount = 0 + } + } + i += 1 + } + val removeField = lastValue.isInstanceOf[Val.Null] + if (removeField) { + keysIter.remove() + } else { + val finalValue = { + if (objCount > 1) recMerge(targetValue, objValues, objCount) + else if (objCount == 1) cleanObject(objValues(0), ev) + else if (lastValue != null) lastValue + else targetValue + } + outputFields.replace(key, createMember(finalValue)) + } + } + new Val.Obj(pos, outputFields, false, null, null) + } + + arr match { + case arr: Val.Arr => + val length = arr.length + if (length == 0) { + target + } else { + val objects = new Array[Val.Obj](length) + var arrIdx = 0 + var objectsIdx = 0 + while (arrIdx < length) { + arr.force(arrIdx) match { + case obj: Val.Obj => + objects(objectsIdx) = obj + objectsIdx += 1 + case _ => + objectsIdx = 0 + } + arrIdx += 1 + } + if (objectsIdx == 0) { + // The last element is a non-object, so it overwrites everything. 
+ arr.force(length - 1) + } else { + recMerge(target, objects, objectsIdx) + } + } + + case v => Error.fail(s"Expected array, got ${v.prettyName}", pos)(ev) + } } } diff --git a/sjsonnet/src/sjsonnet/Val.scala b/sjsonnet/src/sjsonnet/Val.scala index 9e02311d..cb8a5efd 100644 --- a/sjsonnet/src/sjsonnet/Val.scala +++ b/sjsonnet/src/sjsonnet/Val.scala @@ -8,6 +8,7 @@ import sjsonnet.Expr.Params import scala.annotation.tailrec import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.collection.JavaConverters._ import scala.reflect.ClassTag /** @@ -227,6 +228,9 @@ object Val{ @inline def containsVisibleKey(k: String): Boolean = getAllKeys.get(k) == java.lang.Boolean.FALSE + def allKeysSize: Int = getAllKeys.size() + def hasAPlusField: Boolean = getValue0.values().asScala.exists(_.add) + lazy val allKeyNames: Array[String] = getAllKeys.keySet().toArray(new Array[String](getAllKeys.size())) lazy val visibleKeyNames: Array[String] = if(static) allKeyNames else { diff --git a/sjsonnet/test/src/sjsonnet/FoldlMergePatchSpecializationTests.scala b/sjsonnet/test/src/sjsonnet/FoldlMergePatchSpecializationTests.scala new file mode 100644 index 00000000..73770c16 --- /dev/null +++ b/sjsonnet/test/src/sjsonnet/FoldlMergePatchSpecializationTests.scala @@ -0,0 +1,151 @@ +package sjsonnet + +import utest._ +import TestUtils.{assertSameEvalWithAndWithoutSpecialization, evalErr} + +object FoldlMergePatchSpecializationTests extends TestSuite { + + @noinline + private def check(s: String): Unit = { + assertSameEvalWithAndWithoutSpecialization(s) + } + + def tests = Tests { + test("empty array handling") { + check("""std.foldl(std.mergePatch, [], {})""") + check("""std.foldl(std.mergePatch, [{}], {})""") + } + + test("single patch to empty object") { + check("""std.foldl(std.mergePatch, [1], {})""") + check("""std.foldl(std.mergePatch, [null], {})""") + check("""std.foldl(std.mergePatch, [{}], {})""") + check("""std.foldl(std.mergePatch, [{a: 1}], {})""") 
+ check("""std.foldl(std.mergePatch, [{a: null}], {})""") + check("""std.foldl(std.mergePatch, [{a: {b: null}}], {})""") + check("""std.foldl(std.mergePatch, [{a: {b: {c: null}}}], {})""") + check("""std.objectFieldsAll(std.foldl(std.mergePatch, [{a: 1, b:: 1}], {}))""") + check("""std.objectFieldsAll(std.foldl(std.mergePatch, [{a: {b: { c:: 1, d: 2}}}], {}).a.b)""") + } + + test("basic non-nested merging") { + check("""std.foldl(std.mergePatch, [{a: 1}, {b: 2}], {})""") + check("""std.foldl(std.mergePatch, [{b: 2}], {a: 1})""") + + check("""std.foldl(std.mergePatch, [{a: 1}, {b: 2}, {c: 3}], {})""") + check("""std.foldl(std.mergePatch, [{b: 2}, {c: 3}], {a: 1})""") + + check("""std.foldl(std.mergePatch, [{a: 1}, {a: 2}, {a: 3}], {})""") + check("""std.foldl(std.mergePatch, [{a: 2}, {a: 3}], {a: 1})""") + + check("""std.foldl(std.mergePatch, [{a: 1, b: 1}, {b: 2, c: 2}, {c: 3, d: 3}], {})""") + check("""std.foldl(std.mergePatch, [{b: 2, c: 2}, {c: 3, d: 3}], {a: 1, b: 1})""") + } + + test("merging of non-object patches") { + check("""std.foldl(std.mergePatch, [{a: 1}, 1, 2, {a: 3}], {})""") + check("""std.foldl(std.mergePatch, [1, 2, 3, 4], {})""") + check("""std.foldl(std.mergePatch, [1, {a: 1}], {})""") + } + + test("nested object merging") { + check("""std.foldl(std.mergePatch, [{a: {x: 1}}, {a: {y: 2}}], {})""") + check("""std.foldl(std.mergePatch, [{a: {x: 1, y: 1}}, {a: {y: 2}}], {})""") + check("""std.foldl(std.mergePatch, [{a: {b: {x: 1}}}, {a: {b: {y: 2}}}], {})""") + check("""std.foldl(std.mergePatch, [{a: {x: 1}, b: 1}, {a: {y: 2}, c: 2}], {})""") + } + + test("null handling") { + check("""std.foldl(std.mergePatch, [{a: {x: 1, y: 1}}, {a: {y: null}}], {})""") + check("""std.foldl(std.mergePatch, [{a: 1, b: 1}, {a: null}], {})""") + check("""std.foldl(std.mergePatch, [{a: 1}, {a: null}, {a: 2}], {})""") + + check(""" + local arr = [ + {a: {x: 1, y: 1}, b: 1}, + {a: {y: null}, b: null}, + {a: {z: 3}} + ]; + std.foldl(std.mergePatch, arr, {}) + """) + } + + 
test("hidden field handling") { + // Hidden fields should always be dropped + check("""std.foldl(std.mergePatch, [{a:: 1}, {b: 2}], {})""") + check("""std.objectFieldsAll(std.foldl(std.mergePatch, [{a:: 1}, {b: 2}], {}))""") + check("""std.foldl(std.mergePatch, [{b: 2}], {a:: 1})""") + check("""std.objectFieldsAll(std.foldl(std.mergePatch, [{b: 2}], {a:: 1}))""") + + check("""std.foldl(std.mergePatch, [], {a:: 1})""") + check("""std.objectFieldsAll(std.foldl(std.mergePatch, [], {a:: 1}))""") + + check("""std.foldl(std.mergePatch, [{b: 2}], {a:: 1})""") + check("""std.objectFieldsAll(std.foldl(std.mergePatch, [{b: 2}], {a:: 1}))""") + + check("""std.foldl(std.mergePatch, [{a: 1}, {b:: 2}], {})""") + check("""std.objectFieldsAll(std.foldl(std.mergePatch, [{a: 1}, {b:: 2}], {}))""") + + // Nested hidden fields should also be dropped + check(""" + local arr = [{a: {h:: 1, v: 1}}, {b: 2}]; + std.foldl(std.mergePatch, arr, {}) + """) + check( + """ + local arr = [{a: {h:: 1, v: 1}}, {b: 2}]; + std.objectFieldsAll(std.foldl(std.mergePatch, arr, {}).a) + """) + + // Hidden fields do not merge with non-hidden fields + check("""std.foldl(std.mergePatch, [{a: {b: 1}}, {a:: {c: 1}}], {})""") + } + + test("ordering preservation") { + check("""std.foldl(std.mergePatch, [{b: 1, a: 1}, {c: 1, d: 1}], {})""") + check("""std.foldl(std.mergePatch, [{a: 2, b: 2}], {b: 1, a: 1})""") + } + + test("plus operator handling") { + // The +: operator should be ignored during merging + check("""std.foldl(std.mergePatch, [{a: 1}, {a+: 2}], {})""") + check("""std.foldl(std.mergePatch, [{a+: 2}], {a: 1})""") + + // The resulting field should not be treated as +: in future merges + check(""" + local result = std.foldl(std.mergePatch, [{}, {a+: 2}], {}); + {a: 1} + result + """) + check( + """ + local result = std.foldl(std.mergePatch, [{}, {a+: 2}], {}); + {a: 1} + result + """) + + // +: in first object should also be ignored + check(""" + local result = std.foldl(std.mergePatch, [{}], {a+: 2}); 
+ {a: 1} + result + """) + + // Should work the same for nested fields + check(""" + local result = std.foldl(std.mergePatch, [{a: {b+: 2}}, {}], {}); + {a: {b: 1}} + result + """) + } + + test("error handling") { + assert(evalErr( + """std.foldl(std.mergePatch, null, {})""", + disableStaticApplyForBuiltinFunctions = true, + disableBuiltinSpecialization = true + ).startsWith("sjsonnet.Error: Cannot call foldl on null")) + assert(evalErr( + """std.foldl(std.mergePatch, null, {})""", + disableStaticApplyForBuiltinFunctions = true, + disableBuiltinSpecialization = false + ).startsWith("sjsonnet.Error: Expected array, got null")) + } + } +} \ No newline at end of file diff --git a/sjsonnet/test/src/sjsonnet/TestUtils.scala b/sjsonnet/test/src/sjsonnet/TestUtils.scala index 298f7e6c..064cf159 100644 --- a/sjsonnet/test/src/sjsonnet/TestUtils.scala +++ b/sjsonnet/test/src/sjsonnet/TestUtils.scala @@ -8,7 +8,9 @@ object TestUtils { strict: Boolean = false, noDuplicateKeysInComprehension: Boolean = false, strictInheritedAssertions: Boolean = false, - strictSetOperations: Boolean = true): Either[String, Value] = { + strictSetOperations: Boolean = true, + disableBuiltinSpecialization: Boolean = false, + disableStaticApplyForBuiltinFunctions: Boolean = false): Either[String, Value] = { new Interpreter( Map(), Map(), @@ -21,6 +23,8 @@ object TestUtils { noDuplicateKeysInComprehension = noDuplicateKeysInComprehension, strictInheritedAssertions = strictInheritedAssertions, strictSetOperations = strictSetOperations, + disableBuiltinSpecialization = disableBuiltinSpecialization, + disableStaticApplyForBuiltinFunctions = disableStaticApplyForBuiltinFunctions, throwErrorForInvalidSets = true ) ).interpret(s, DummyPath("(memory)")) @@ -31,8 +35,18 @@ object TestUtils { strict: Boolean = false, noDuplicateKeysInComprehension: Boolean = false, strictInheritedAssertions: Boolean = false, - strictSetOperations: Boolean = true): Value = { - eval0(s, preserveOrder, strict, 
noDuplicateKeysInComprehension, strictInheritedAssertions, strictSetOperations) match { + strictSetOperations: Boolean = true, + disableBuiltinSpecialization: Boolean = false, + disableStaticApplyForBuiltinFunctions: Boolean = false): Value = { + eval0( + s, + preserveOrder, + strict, + noDuplicateKeysInComprehension, + strictInheritedAssertions, + strictSetOperations, + disableBuiltinSpecialization, + disableStaticApplyForBuiltinFunctions) match { case Right(x) => x case Left(e) => throw new Exception(e) } @@ -43,10 +57,48 @@ object TestUtils { strict: Boolean = false, noDuplicateKeysInComprehension: Boolean = false, strictInheritedAssertions: Boolean = false, - strictSetOperations: Boolean = true): String = { - eval0(s, preserveOrder, strict, noDuplicateKeysInComprehension, strictInheritedAssertions, strictSetOperations) match{ + strictSetOperations: Boolean = true, + disableBuiltinSpecialization: Boolean = false, + disableStaticApplyForBuiltinFunctions: Boolean = false): String = { + eval0( + s, + preserveOrder, + strict, + noDuplicateKeysInComprehension, + strictInheritedAssertions, + strictSetOperations, + disableBuiltinSpecialization, + disableStaticApplyForBuiltinFunctions) match { case Left(err) => err.split('\n').map(_.trim).mkString("\n") // normalize inconsistent indenation on JVM vs JS case Right(r) => throw new Exception(s"Expected exception, got result: $r") } } + + def assertSameEvalWithAndWithoutSpecialization(s: String): Unit = { + // We have to disable static application of built-in functions, otherwise + // it will be folded in the static optimizer before we even have a chance to + // perform specialization. 
+ val noSpecialization = eval( + s, + preserveOrder = true, + disableStaticApplyForBuiltinFunctions = true, + disableBuiltinSpecialization = true) + val withSpecialization = eval( + s, + preserveOrder = true, + disableStaticApplyForBuiltinFunctions = true, + disableBuiltinSpecialization = false) + // Render both results as strings so a mismatch prints each value under the label of the run that produced it + val specializedStr = withSpecialization.toString() + val unspecializedStr = noSpecialization.toString() + + assert( + specializedStr == unspecializedStr, + s"""Specialization mismatch for expression: $s + | + |Specialized result: $specializedStr + |Unspecialized result: $unspecializedStr + |""".stripMargin + ) + } }