Skip to content

Commit

Permalink
Refactor LD back-substitution (CSR version).
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 717403904
Change-Id: Idcd5e71f03a960203c22cb737d399fac5a0ba59c
  • Loading branch information
yuvaltassa authored and copybara-github committed Jan 20, 2025
1 parent 2a10054 commit 5c4c79c
Showing 1 changed file with 53 additions and 59 deletions.
112 changes: 53 additions & 59 deletions src/engine/engine_core_smooth.c
Original file line number Diff line number Diff line change
Expand Up @@ -1609,84 +1609,78 @@ void mj_solveLD(const mjModel* m, mjtNum* restrict x, int n,
// like mj_solveLD, but using the CSR representation of L
void mj_solveLDs(mjtNum* restrict x, const mjtNum* qLDs, const mjtNum* qLDiagInv, int nv, int n,
const int* rownnz, const int* rowadr, const int* diagnum, const int* colind) {
// single vector
if (n == 1) {
// x <- L^-T x
for (int i=nv-1; i > 0; i--) {
// skip diagonal rows, zero elements in input vector
mjtNum x_i = x[i];
if (x_i == 0 || diagnum[i]) {
continue;
}

int start = rowadr[i];
int end = start + rownnz[i] - 1;
for (int adr=start; adr < end; adr++) {
x[colind[adr]] -= qLDs[adr] * x_i;
}
}

// x <- D^-1 x
for (int i=0; i < nv; i++) {
x[i] *= qLDiagInv[i];
// x <- L^-T x
for (int i=nv-1; i > 0; i--) {
// skip diagonal rows
if (diagnum[i]) {
continue;
}

// x <- L^-1 x
for (int i=1; i < nv; i++) {
// skip diagonal rows
if (diagnum[i]) {
i += diagnum[i] - 1; // iterating forward: skip ahead, adjust i
continue;
// one vector
if (n == 1) {
mjtNum x_i;
if ((x_i = x[i])) {
int start = rowadr[i];
int end = start + rownnz[i] - 1;
for (int adr=start; adr < end; adr++) {
x[colind[adr]] -= qLDs[adr] * x_i;
}
}

int adr = rowadr[i];
x[i] -= mju_dotSparse(qLDs+adr, x, rownnz[i] - 1, colind+adr, /*flg_unc1=*/0);
}
}

// multiple vectors
else {
// x <- L^-T x
for (int i=nv-1; i > 0; i--) {
// skip diagonal rows
if (diagnum[i]) {
continue;
}

// multiple vectors
else {
int start = rowadr[i];
int end = start + rownnz[i] - 1;
for (int adr=start; adr < end; adr++) {
int j = colind[adr];
mjtNum val = qLDs[adr];
for (int offset=0; offset < n*nv; offset+=nv) {
mjtNum x_i;
if ((x_i = x[i+offset])) {
x[j+offset] -= val * x_i;
for (int offset=0; offset < n*nv; offset+=nv) {
mjtNum x_i;
if ((x_i = x[i+offset])) {
for (int adr=start; adr < end; adr++) {
x[offset + colind[adr]] -= qLDs[adr] * x_i;
}
}
}
}
}

// x <- D^-1 x
for (int i=0; i < nv; i++) {
mjtNum invD_i = qLDiagInv[i];

// one vector
if (n == 1) {
x[i] *= invD_i;
}

// x <- D^-1 x
for (int i=0; i < nv; i++) {
mjtNum invD_i = qLDiagInv[i];
// multiple vectors
else {
for (int offset=0; offset < n*nv; offset+=nv) {
x[i+offset] *= invD_i;
}
}
}

// x <- L^-1 x
for (int i=1; i < nv; i++) {
// skip diagonal rows
if (diagnum[i]) {
i += diagnum[i] - 1; // iterating forward: skip ahead, adjust i
continue;
// x <- L^-1 x
for (int i=1; i < nv; i++) {
// skip diagonal rows
if (diagnum[i]) {
i += diagnum[i] - 1; // iterating forward: skip ahead, adjust i
continue;
}

int adr = rowadr[i];
int d = rownnz[i] - 1;
if (d > 0) {
// one vector
if (n == 1) {
x[i] -= mju_dotSparse(qLDs+adr, x, d, colind+adr, /*flg_unc1=*/0);
}

int adr = rowadr[i];
int d = rownnz[i] - 1;
for (int offset=0; offset < n*nv; offset+=nv) {
x[i+offset] -= mju_dotSparse(qLDs+adr, x+offset, d, colind+adr, /*flg_unc1=*/0);
// multiple vectors
else {
for (int offset=0; offset < n*nv; offset+=nv) {
x[i+offset] -= mju_dotSparse(qLDs+adr, x+offset, d, colind+adr, /*flg_unc1=*/0);
}
}
}
}
Expand Down

0 comments on commit 5c4c79c

Please sign in to comment.