From 1123675944853c7bebba1aab54f3b4bea4d6bc17 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 7 Feb 2023 18:58:58 -0500 Subject: [PATCH] [Util] NestedMsg: MapNestedMsg (#410) This PR brings a new util function (`MapNestedMsg`) for NestedMsg: * `MapNestedMsg` recurses into an input NestedMsg and maps the leaves according to the provided mapping function. One corresponding unit test is provided. --- include/tvm/relax/nested_msg.h | 26 ++++++++++++++++++++++++++ tests/cpp/nested_msg_test.cc | 23 +++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/include/tvm/relax/nested_msg.h b/include/tvm/relax/nested_msg.h index 734045eee3..e0f60be36b 100644 --- a/include/tvm/relax/nested_msg.h +++ b/include/tvm/relax/nested_msg.h @@ -427,6 +427,32 @@ NestedMsg CombineNestedMsg(NestedMsg lhs, NestedMsg rhs, FType fcombine } } +/*! + * \brief Recursively map a nested message to another one, with leaf mapped by the input fmapleaf. + * \param msg The nested message to be mapped. + * \param fmapleaf The leaf map function, with signature NestedMsg fmapleaf(T msg) + * \tparam T The content type of nested message. + * \tparam FType The leaf map function type. + * \return The new nested message. + */ +template +NestedMsg MapNestedMsg(NestedMsg msg, FType fmapleaf) { + if (msg.IsNull()) { + return msg; + } else if (msg.IsLeaf()) { + return fmapleaf(msg.LeafValue()); + } else { + ICHECK(msg.IsNested()); + Array> arr = msg.NestedArray(); + Array> res; + res.reserve(arr.size()); + for (int i = 0; i < static_cast(arr.size()); ++i) { + res.push_back(MapNestedMsg(arr[i], fmapleaf)); + } + return NestedMsg(res); + } +} + /*! * \brief Recursively decompose the tuple structure in expr and msg along with it. * diff --git a/tests/cpp/nested_msg_test.cc b/tests/cpp/nested_msg_test.cc index 8e5c22b18e..550dc37515 100644 --- a/tests/cpp/nested_msg_test.cc +++ b/tests/cpp/nested_msg_test.cc @@ -260,6 +260,29 @@ TEST(NestedMsg, CombineNestedMsg) { [](Integer lhs, Integer rhs) -> bool { return lhs->value == rhs->value; })); } +TEST(NestedMsg, MapNestedMsg) { + auto c0 = Integer(0); + auto c1 = Integer(1); + auto c2 = Integer(2); + auto c3 = Integer(3); + + NestedMsg msg = {c0, {c0, c1}, NullOpt, {c0, {c2, c1}}}; + NestedMsg expected = {c3, {c3, NullOpt}, NullOpt, {c3, {c2, NullOpt}}}; + + auto output = MapNestedMsg(msg, [](Integer x) { + if (x->value == 0) { + return NestedMsg(Integer(3)); + } else if (x->value == 1) { + return NestedMsg(); + } else { + return NestedMsg(x); + } + }); + + EXPECT_TRUE(Equal(output, expected, + [](Integer lhs, Integer rhs) -> bool { return lhs->value == rhs->value; })); +} + TEST(NestedMsg, TransformTupleLeaf) { auto c0 = Integer(0); auto c1 = Integer(1);