From 4d82ab57626f2036e6b7b928628456264f88fea7 Mon Sep 17 00:00:00 2001 From: Yuval Tassa Date: Tue, 7 Jan 2025 04:41:09 -0800 Subject: [PATCH] Add function to merge chains of indices. PiperOrigin-RevId: 712863690 Change-Id: I15652bec03dc9ce90e230788bd12b44b1a2b8217 --- src/engine/engine_util_sparse.c | 11 ----- src/engine/engine_util_sparse.h | 65 ++++++++++++++++++++++++++ test/engine/engine_util_sparse_test.cc | 18 +++++++ 3 files changed, 83 insertions(+), 11 deletions(-) diff --git a/src/engine/engine_util_sparse.c b/src/engine/engine_util_sparse.c index e4e1c1c98b..6fbb0ff4e1 100644 --- a/src/engine/engine_util_sparse.c +++ b/src/engine/engine_util_sparse.c @@ -206,17 +206,6 @@ static void mju_addToSclScl(mjtNum* res, const mjtNum* vec, mjtNum scl1, mjtNum -// return 1 if vec1==vec2, 0 otherwise -static int mju_compare(const int* vec1, const int* vec2, int n) { -#ifdef mjUSEAVX - return mju_compare_avx(vec1, vec2, n); -#else - return !memcmp(vec1, vec2, n*sizeof(int)); -#endif // mjUSEAVX -} - - - // count the number of non-zeros in the sum of two sparse vectors int mju_combineSparseCount(int a_nnz, int b_nnz, const int* a_ind, const int* b_ind) { int a = 0, b = 0, c_nnz = 0; diff --git a/src/engine/engine_util_sparse.h b/src/engine/engine_util_sparse.h index 04c376103b..34bc777fc7 100644 --- a/src/engine/engine_util_sparse.h +++ b/src/engine/engine_util_sparse.h @@ -15,6 +15,8 @@ #ifndef MUJOCO_SRC_ENGINE_ENGINE_UTIL_SPARSE_H_ #define MUJOCO_SRC_ENGINE_ENGINE_UTIL_SPARSE_H_ +#include + #include #include #include @@ -162,6 +164,69 @@ mjtNum mju_dotSparse(const mjtNum* vec1, const mjtNum* vec2, int nnz1, const int #endif // mjUSEAVX } +// return 1 if vec1==vec2, 0 otherwise +static inline +int mju_compare(const int* vec1, const int* vec2, int n) { +#ifdef mjUSEAVX + return mju_compare_avx(vec1, vec2, n); +#else + return !memcmp(vec1, vec2, n*sizeof(int)); +#endif // mjUSEAVX +} + + +// merge unique sorted integers, merge array must be large enough (not checked for) +static inline +int mj_mergeSorted(int* merge, const int* chain1, int n1, const int* chain2, int n2) { + // special case: one or both empty + if (n1 == 0) { + if (n2 == 0) { + return 0; + } + memcpy(merge, chain2, n2 * sizeof(int)); + return n2; + } else if (n2 == 0) { + memcpy(merge, chain1, n1 * sizeof(int)); + return n1; + } + + // special case: identical pattern + if (n1 == n2 && mju_compare(chain1, chain2, n1)) { + memcpy(merge, chain1, n1 * sizeof(int)); + return n1; + } + + // merge while both chains are non-empty + int i = 0, j = 0, k = 0; + while (i < n1 && j < n2) { + int c1 = chain1[i]; + int c2 = chain2[j]; + + if (c1 < c2) { + merge[k++] = c1; + i++; + } else if (c1 > c2) { + merge[k++] = c2; + j++; + } else { // c1 == c2 + merge[k++] = c1; + i++; + j++; + } + } + + // copy remaining + if (i < n1) { + memcpy(merge + k, chain1 + i, (n1 - i)*sizeof(int)); + k += n1 - i; + } else if (j < n2) { + memcpy(merge + k, chain2 + j, (n2 - j)*sizeof(int)); + k += n2 - j; + } + + return k; +} + #ifdef __cplusplus } diff --git a/test/engine/engine_util_sparse_test.cc b/test/engine/engine_util_sparse_test.cc index d1057b225a..20a543dfe8 100644 --- a/test/engine/engine_util_sparse_test.cc +++ b/test/engine/engine_util_sparse_test.cc @@ -1100,5 +1100,23 @@ TEST_F(EngineUtilSparseTest, MjuDenseToSparse) { EXPECT_EQ(status0, 1); } +TEST_F(EngineUtilSparseTest, MergeSorted) { + const int chain1_a[] = {1, 2, 3}; + const int chain2_a[] = {}; + int merged_a[3]; + int n1 = 3; + int n2 = 0; + EXPECT_EQ(mj_mergeSorted(merged_a, chain1_a, n1, chain2_a, n2), 3); + EXPECT_THAT(merged_a, ElementsAre(1, 2, 3)); + + const int chain1_b[] = {1, 3, 5, 7, 8}; + const int chain2_b[] = {2, 4, 5, 6, 8}; + int merged_b[8]; + n1 = 5; + n2 = 5; + EXPECT_EQ(mj_mergeSorted(merged_b, chain1_b, n1, chain2_b, n2), 8); + EXPECT_THAT(merged_b, ElementsAre(1, 2, 3, 4, 5, 6, 7, 8)); +} + } // namespace } // namespace mujoco