From 9736acbe53cef660ecf02ae6344de1a0762d7e9b Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 19 Sep 2024 09:11:22 +0200 Subject: [PATCH] Simplify CRBA by merging two jax.lax.cond branches --- src/jaxsim/rbda/crba.py | 58 ++++++++++++++++++++--------------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/src/jaxsim/rbda/crba.py b/src/jaxsim/rbda/crba.py index 1d45dbd8d..adb3506ae 100644 --- a/src/jaxsim/rbda/crba.py +++ b/src/jaxsim/rbda/crba.py @@ -94,48 +94,48 @@ def backward_pass( j = i - CarryInnerFn = tuple[jtp.Int, jtp.Matrix, jtp.Matrix] - carry_inner_fn = (j, Fi, M) + FakeWhileCarry = tuple[jtp.Int, jtp.Vector, jtp.Matrix] + fake_while_carry = (j, Fi, M) - def while_loop_body(carry: CarryInnerFn) -> CarryInnerFn: - j, Fi, M = carry + # This internal for loop implements the while loop of the CRBA algorithm + # to compute off-diagonal blocks of the mass matrix M. + # In pseudocode it is implemented as a while loop. However, in order to enable + # applying reverse-mode AD, we implement it as a nested for loop with a fixed + # number of iterations and a branching model to skip for loop iterations. + def fake_while_loop( + carry: FakeWhileCarry, i: jtp.Int + ) -> tuple[FakeWhileCarry, None]: - Fi = i_X_λi[j].T @ Fi - j = λ[j] - jj = j - 1 + def compute(carry: FakeWhileCarry) -> FakeWhileCarry: - M_ij = Fi.T @ S[j] + j, Fi, M = carry - M = M.at[ii + 6, jj + 6].set(M_ij.squeeze()) - M = M.at[jj + 6, ii + 6].set(M_ij.squeeze()) + Fi = i_X_λi[j].T @ Fi + j = λ[j] - return j, Fi, M + M_ij = Fi.T @ S[j] - # The following functions are part of a (rather messy) workaround for computing - # a while loop using a for loop with fixed number of iterations. - def inner_fn(carry: CarryInnerFn, k: jtp.Int) -> tuple[CarryInnerFn, None]: - def compute_inner(carry: CarryInnerFn) -> tuple[CarryInnerFn, None]: - j, _, _ = carry - out = jax.lax.cond( - pred=(λ[j] > 0), - true_fun=while_loop_body, - false_fun=lambda carry: carry, - operand=carry, - ) - return out, None + jj = j - 1 + M = M.at[ii + 6, jj + 6].set(M_ij.squeeze()) + M = M.at[jj + 6, ii + 6].set(M_ij.squeeze()) + + return j, Fi, M j, _, _ = carry - return jax.lax.cond( - pred=(k == j), - true_fun=compute_inner, - false_fun=lambda carry: (carry, None), + + j, Fi, M = jax.lax.cond( + pred=jnp.logical_and(i == λ[j], λ[j] > 0), + true_fun=compute, + false_fun=lambda carry: carry, operand=carry, ) + return (j, Fi, M), None + (j, Fi, M), _ = ( jax.lax.scan( - f=inner_fn, - init=carry_inner_fn, + f=fake_while_loop, + init=fake_while_carry, xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())), ) if model.number_of_links() > 1