Skip to content

Commit

Permalink
Add function to merge chains of indices.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 712863690
Change-Id: I15652bec03dc9ce90e230788bd12b44b1a2b8217
  • Loading branch information
yuvaltassa authored and copybara-github committed Jan 7, 2025
1 parent 1c69e64 commit 4d82ab5
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 11 deletions.
11 changes: 0 additions & 11 deletions src/engine/engine_util_sparse.c
Original file line number Diff line number Diff line change
Expand Up @@ -206,17 +206,6 @@ static void mju_addToSclScl(mjtNum* res, const mjtNum* vec, mjtNum scl1, mjtNum



// return 1 if vec1==vec2, 0 otherwise
static int mju_compare(const int* vec1, const int* vec2, int n) {
#ifdef mjUSEAVX
return mju_compare_avx(vec1, vec2, n);
#else
return !memcmp(vec1, vec2, n*sizeof(int));
#endif // mjUSEAVX
}



// count the number of non-zeros in the sum of two sparse vectors
int mju_combineSparseCount(int a_nnz, int b_nnz, const int* a_ind, const int* b_ind) {
int a = 0, b = 0, c_nnz = 0;
Expand Down
65 changes: 65 additions & 0 deletions src/engine/engine_util_sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#ifndef MUJOCO_SRC_ENGINE_ENGINE_UTIL_SPARSE_H_
#define MUJOCO_SRC_ENGINE_ENGINE_UTIL_SPARSE_H_

#include <string.h>

#include <mujoco/mjdata.h>
#include <mujoco/mjexport.h>
#include <mujoco/mjtnum.h>
Expand Down Expand Up @@ -162,6 +164,69 @@ mjtNum mju_dotSparse(const mjtNum* vec1, const mjtNum* vec2, int nnz1, const int
#endif // mjUSEAVX
}

// return 1 if vec1==vec2, 0 otherwise
static inline
int mju_compare(const int* vec1, const int* vec2, int n) {
#ifdef mjUSEAVX
return mju_compare_avx(vec1, vec2, n);
#else
return !memcmp(vec1, vec2, n*sizeof(int));
#endif // mjUSEAVX
}


// merge unique sorted integers, merge array must be large enough (not checked for)
static inline
int mj_mergeSorted(int* merge, const int* chain1, int n1, const int* chain2, int n2) {
// special case: one or both empty
if (n1 == 0) {
if (n2 == 0) {
return 0;
}
memcpy(merge, chain2, n2 * sizeof(int));
return n2;
} else if (n2 == 0) {
memcpy(merge, chain1, n1 * sizeof(int));
return n1;
}

// special case: identical pattern
if (n1 == n2 && mju_compare(chain1, chain2, n1)) {
memcpy(merge, chain1, n1 * sizeof(int));
return n1;
}

// merge while both chains are non-empty
int i = 0, j = 0, k = 0;
while (i < n1 && j < n2) {
int c1 = chain1[i];
int c2 = chain2[j];

if (c1 < c2) {
merge[k++] = c1;
i++;
} else if (c1 > c2) {
merge[k++] = c2;
j++;
} else { // c1 == c2
merge[k++] = c1;
i++;
j++;
}
}

// copy remaining
if (i < n1) {
memcpy(merge + k, chain1 + i, (n1 - i)*sizeof(int));
k += n1 - i;
} else if (j < n2) {
memcpy(merge + k, chain2 + j, (n2 - j)*sizeof(int));
k += n2 - j;
}

return k;
}


#ifdef __cplusplus
}
Expand Down
18 changes: 18 additions & 0 deletions test/engine/engine_util_sparse_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1100,5 +1100,23 @@ TEST_F(EngineUtilSparseTest, MjuDenseToSparse) {
EXPECT_EQ(status0, 1);
}

TEST_F(EngineUtilSparseTest, MergeSorted) {
const int chain1_a[] = {1, 2, 3};
const int chain2_a[] = {};
int merged_a[3];
int n1 = 3;
int n2 = 0;
EXPECT_EQ(mj_mergeSorted(merged_a, chain1_a, n1, chain2_a, n2), 3);
EXPECT_THAT(merged_a, ElementsAre(1, 2, 3));

const int chain1_b[] = {1, 3, 5, 7, 8};
const int chain2_b[] = {2, 4, 5, 6, 8};
int merged_b[8];
n1 = 5;
n2 = 5;
EXPECT_EQ(mj_mergeSorted(merged_b, chain1_b, n1, chain2_b, n2), 8);
EXPECT_THAT(merged_b, ElementsAre(1, 2, 3, 4, 5, 6, 7, 8));
}

} // namespace
} // namespace mujoco

0 comments on commit 4d82ab5

Please sign in to comment.