From 451609f52a6fe746ee163fde464ce2b3c44d4585 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 12 Sep 2024 16:02:37 +0200 Subject: [PATCH 1/8] Add recursive algorithm to compute the inverse of the mass matrix --- src/jaxsim/rbda/mass_inverse.py | 200 ++++++++++++++++++++++++++++++++ 1 file changed, 200 insertions(+) create mode 100644 src/jaxsim/rbda/mass_inverse.py diff --git a/src/jaxsim/rbda/mass_inverse.py b/src/jaxsim/rbda/mass_inverse.py new file mode 100644 index 000000000..dbf16e384 --- /dev/null +++ b/src/jaxsim/rbda/mass_inverse.py @@ -0,0 +1,200 @@ +import jax +import jax.numpy as jnp +import jaxlie + +import jaxsim.api as js +import jaxsim.typing as jtp + +from . import utils + + +def mass_inverse( + model: js.model.JaxSimModel, + *, + base_position: jtp.VectorLike, + base_quaternion: jtp.VectorLike, + joint_positions: jtp.VectorLike, +) -> tuple[jtp.Vector, jtp.Vector]: + """ + Compute the inverse of the mass matrix using the Articulated Body Algorithm (ABA). + + Args: + model: The model to consider. + base_position: The position of the base link. + base_quaternion: The quaternion of the base link. + joint_positions: The positions of the joints. + + Returns: + The inverse of the free-floating mass matrix. + + Note: + The algorithm expects a quaternion with unit norm. + """ + + W_p_B, W_Q_B, s, W_v_WB, _, _, _, _, _, _ = utils.process_inputs( + model=model, + base_position=base_position, + base_quaternion=base_quaternion, + joint_positions=joint_positions, + ) + + W_v_WB = jnp.atleast_2d(W_v_WB).T + + # Get the 6D spatial inertia matrices of all links. + M = js.model.link_spatial_inertia_matrices(model=model) + + # Get the parent array λ(i). + # Note: λ(0) must not be used, it's initialized to -1. + λ = model.kin_dyn_parameters.parent_array + + # Compute the base transform. + W_H_B = jaxlie.SE3.from_rotation_and_translation( + rotation=jaxlie.SO3(wxyz=W_Q_B), + translation=W_p_B, + ) + + # Compute the parent-to-child adjoints and the motion subspaces of the joints. + # These transforms define the relative kinematics of the entire model, including + # the base transform for both floating-base and fixed-base models. + i_X_λi, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( + joint_positions=s, base_transform=W_H_B.as_matrix() + ) + + # Allocate buffers. + MA = jnp.zeros(shape=(model.number_of_links(), 6, 6)) + M_inv = jnp.zeros( + shape=( + model.number_of_links() + 6 * model.floating_base(), + model.number_of_links() + 6 * model.floating_base(), + ) + ) + + # Allocate the buffer of transforms link -> base. + i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6)) + i_X_0 = i_X_0.at[0].set(jnp.eye(6)) + + # Initialize base quantities. + if model.floating_base(): + + # Initialize the articulated-body inertia (Mᴬ) of base link. + MA_0 = M[0] + MA = MA.at[0].set(MA_0) + + # ====== + # Pass 1 + # ====== + + Pass1Carry = tuple[jtp.Matrix, jtp.Matrix] + pass_1_carry: Pass1Carry = (MA, i_X_0) + + # Propagate kinematics and initialize AB inertia and AB bias forces. + def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> tuple[Pass1Carry, None]: + + v, c, MA, i_X_0 = carry + + # Initialize the articulated-body inertia. + MA_i = jnp.array(M[i]) + MA = MA.at[i].set(MA_i) + + # Compute the link-to-base transform. + i_Xi_0 = i_X_λi[i] @ i_X_0[λ[i]] + i_X_0 = i_X_0.at[i].set(i_Xi_0) + + return (v, c, MA, i_X_0), None + + (MA, i_X_0), _ = ( + jax.lax.scan( + f=loop_body_pass1, + init=pass_1_carry, + xs=jnp.arange(start=1, stop=model.number_of_links()), + ) + if model.number_of_links() > 1 + else [(MA, i_X_0), None] + ) + + # ====== + # Pass 2 + # ====== + + U = jnp.zeros_like(S) + d = jnp.zeros(shape=(model.number_of_links(), 1)) + + Pass2Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix] + pass_2_carry: Pass2Carry = (U, d, MA) + + def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]: + + ( + U, + d, + MA, + ) = carry + + U_i = MA[i] @ S[i] + U = U.at[i].set(U_i) + + d_i = S[i].T @ U[i] + d = d.at[i].set(d_i.squeeze()) + + # Compute the articulated-body inertia and bias force of this link. + Ma = MA[i] - U[i] / d[i] @ U[i].T + + # Propagate them to the parent, handling the base link. + def propagate( + MA: tuple[jtp.Matrix, jtp.Matrix] + ) -> tuple[jtp.Matrix, jtp.Matrix]: + + MA_λi = MA[λ[i]] + i_X_λi[i].T @ Ma @ i_X_λi[i] + MA = MA.at[λ[i]].set(MA_λi) + + return MA + + MA = jax.lax.cond( + pred=jnp.logical_or(λ[i] != 0, model.floating_base()), + true_fun=propagate, + false_fun=lambda MA: MA, + operand=MA, + ) + + return (U, d, MA), None + + (U, d, MA), _ = ( + jax.lax.scan( + f=loop_body_pass2, + init=pass_2_carry, + xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())), + ) + if model.number_of_links() > 1 + else [(U, d, MA), None] + ) + + # ====== + # Pass 3 + # ====== + + F = jnp.zeros_like(s) + + Pass3Carry = tuple[jtp.Matrix, jtp.Vector] + pass_3_carry = (M_inv, F) + + def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> tuple[Pass3Carry, None]: + + M_inv, F = carry + + return (M_inv, F), None + + (M_inv, F), _ = ( + jax.lax.scan( + f=loop_body_pass3, + init=pass_3_carry, + xs=jnp.arange(1, model.number_of_links()), + ) + if model.number_of_links() > 1 + else [(M_inv, F), None] + ) + + # ============== + # Adjust outputs + # ============== + + return M_inv From 548286b2080cc08ea1035af6dba247c4715abfa3 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 12 Sep 2024 20:43:26 +0200 Subject: [PATCH 2/8] [WIP] --- src/jaxsim/rbda/mass_inverse.py | 124 +++++++++++++++++++++++++------- 1 file changed, 100 insertions(+), 24 deletions(-) diff --git a/src/jaxsim/rbda/mass_inverse.py b/src/jaxsim/rbda/mass_inverse.py index dbf16e384..2e38f631c 100644 --- a/src/jaxsim/rbda/mass_inverse.py +++ b/src/jaxsim/rbda/mass_inverse.py @@ -47,12 +47,28 @@ def mass_inverse( # Note: λ(0) must not be used, it's initialized to -1. λ = model.kin_dyn_parameters.parent_array + # Build the ν(i) array, containing all the bodies in the subtree supported + # by joint i. Similarly to κ(i), we compute the boolean version νb(i), so that + # it can be stored in a matrix. + νb = jnp.zeros(shape=(λ.size, λ.size), dtype=bool) + + for i in reversed(range(len(λ))): + + νb = νb.at[i, i].set(True) + j = λ[i] + + while j > -1: + νb = νb.at[j, i].set(True) + j = λ[j] + # Compute the base transform. W_H_B = jaxlie.SE3.from_rotation_and_translation( rotation=jaxlie.SO3(wxyz=W_Q_B), translation=W_p_B, ) + νb = jnp.array(νb) + # Compute the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. @@ -66,6 +82,7 @@ def mass_inverse( shape=( model.number_of_links() + 6 * model.floating_base(), model.number_of_links() + 6 * model.floating_base(), + 1, ) ) @@ -90,7 +107,7 @@ def mass_inverse( # Propagate kinematics and initialize AB inertia and AB bias forces. def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> tuple[Pass1Carry, None]: - v, c, MA, i_X_0 = carry + MA, i_X_0 = carry # Initialize the articulated-body inertia. MA_i = jnp.array(M[i]) @@ -100,7 +117,7 @@ def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> tuple[Pass1Carry, None]: i_Xi_0 = i_X_λi[i] @ i_X_0[λ[i]] i_X_0 = i_X_0.at[i].set(i_Xi_0) - return (v, c, MA, i_X_0), None + return (MA, i_X_0), None (MA, i_X_0), _ = ( jax.lax.scan( @@ -118,17 +135,16 @@ def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> tuple[Pass1Carry, None]: U = jnp.zeros_like(S) d = jnp.zeros(shape=(model.number_of_links(), 1)) + F = jnp.zeros(shape=(6, model.number_of_links())) - Pass2Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix] - pass_2_carry: Pass2Carry = (U, d, MA) + Pass2Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix] + pass_2_carry: Pass2Carry = (U, d, M_inv, MA, F) def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]: - ( - U, - d, - MA, - ) = carry + (U, d, M_inv, MA, F) = carry + + ν = jnp.where(νb[i], size=model.number_of_links())[0] U_i = MA[i] @ S[i] U = U.at[i].set(U_i) @@ -139,58 +155,118 @@ def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]: # Compute the articulated-body inertia and bias force of this link. Ma = MA[i] - U[i] / d[i] @ U[i].T + M_inv_ii = 1 / d[i] + M_inv = M_inv.at[i, i].set(M_inv_ii) + + M_inv_iν = M_inv[i, ν[i]] - S[i].T @ F[:, ν[i]].squeeze() / d[i].T + M_inv = M_inv.at[i, ν[i]].set(M_inv_iν) + # Propagate them to the parent, handling the base link. def propagate( - MA: tuple[jtp.Matrix, jtp.Matrix] + MA_F: tuple[jtp.Matrix, jtp.Matrix] ) -> tuple[jtp.Matrix, jtp.Matrix]: + MA, F = MA_F + + Fa_λi = F[:, ν[i]] + U[i] @ M_inv[i, ν[i]] + F = F.at[:, ν[i]].set(Fa_λi) MA_λi = MA[λ[i]] + i_X_λi[i].T @ Ma @ i_X_λi[i] MA = MA.at[λ[i]].set(MA_λi) - return MA + return MA, F - MA = jax.lax.cond( + MA, F = jax.lax.cond( pred=jnp.logical_or(λ[i] != 0, model.floating_base()), true_fun=propagate, - false_fun=lambda MA: MA, - operand=MA, + false_fun=lambda MA_F: MA_F, + operand=(MA, F), ) - return (U, d, MA), None + return (U, d, M_inv, MA, F), None - (U, d, MA), _ = ( + (U, d, M_inv, MA, F), _ = ( jax.lax.scan( f=loop_body_pass2, init=pass_2_carry, xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())), ) if model.number_of_links() > 1 - else [(U, d, MA), None] + else [(U, d, M_inv, MA, F), None] ) # ====== # Pass 3 # ====== - F = jnp.zeros_like(s) + P = jnp.zeros( + shape=( + model.number_of_links(), + model.number_of_links(), + model.number_of_links(), + ) + ) - Pass3Carry = tuple[jtp.Matrix, jtp.Vector] - pass_3_carry = (M_inv, F) + Pass3Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix] + pass_3_carry = (U, M_inv, P) def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> tuple[Pass3Carry, None]: - M_inv, F = carry + U, M_inv, P = carry + + mask = jnp.arange(P.shape[1]) >= i # equivalent to [i, i:] + + def propagate_M_inv(M_inv: jtp.Matrix) -> jtp.Matrix: + P_ii = jax.lax.dynamic_slice( + P, (i, i - P.shape[1], P.shape[2]), (P.shape[0], 1) * mask + ) + M_inv_ii = jax.lax.dynamic_slice( + M_inv.squeeze(), (i, i - M_inv.squeeze().shape[0]), i_X_λi[i].shape + ) + M_inv_ii = M_inv_ii.at[:].set(M_inv_ii - U[i].T @ i_X_λi[i] @ P_ii / d[i]) + jax.lax.dynamic_update_slice(M, M_inv_ii, (i, i)) + + return M_inv + + M_inv = jax.lax.cond( + pred=jnp.logical_or(λ[i] != 0, model.floating_base()), + true_fun=propagate_M_inv, + false_fun=lambda M_inv: M_inv, + operand=M_inv, + ) + + M_inv_ii = jax.lax.dynamic_slice( + M_inv, (i, i - d.shape[0], 1), (1, d[i].shape - i, 1) + ) + + P_i = S[i].T @ M_inv_ii + P = P.at[i].set(P_i.squeeze()) + + def propagate_P(P: jtp.Vector) -> jtp.Vector: + P_λii = jax.lax.dynamic_slice(P, (λ[i], i), (1, i)) + P_iii = jax.lax.dynamic_slice(P, (i, i), (1, i)) + + P_iii = P_iii.at[:].set(P_iii + i_X_λi[i].T @ P_λii) + jax.lax.dynamic_update_slice(P, P_iii, (i, i)) + + return P + + P = jax.lax.cond( + pred=jnp.logical_or(λ[i] != 0, model.floating_base()), + true_fun=propagate_P, + false_fun=lambda P: P, + operand=P, + ) - return (M_inv, F), None + return (U, M_inv, P), None - (M_inv, F), _ = ( + (U, M_inv, P), _ = ( jax.lax.scan( f=loop_body_pass3, init=pass_3_carry, xs=jnp.arange(1, model.number_of_links()), ) if model.number_of_links() > 1 - else [(M_inv, F), None] + else [(U, M_inv, P), None] ) # ============== From 7ac00b71cac10f80eddec72a27d50c9a5723bc72 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Fri, 13 Sep 2024 17:32:16 +0200 Subject: [PATCH 3/8] [ci skip] Use boolean mask for slicing --- src/jaxsim/rbda/mass_inverse.py | 38 +++++++++++---------------------- 1 file changed, 12 insertions(+), 26 deletions(-) diff --git a/src/jaxsim/rbda/mass_inverse.py b/src/jaxsim/rbda/mass_inverse.py index 2e38f631c..61ee70be1 100644 --- a/src/jaxsim/rbda/mass_inverse.py +++ b/src/jaxsim/rbda/mass_inverse.py @@ -154,6 +154,7 @@ def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]: # Compute the articulated-body inertia and bias force of this link. Ma = MA[i] - U[i] / d[i] @ U[i].T + Fa = F[i, :, ν[i]] + U[i] @ M_inv[i, ν[i]] M_inv_ii = 1 / d[i] M_inv = M_inv.at[i, i].set(M_inv_ii) @@ -167,7 +168,7 @@ def propagate( ) -> tuple[jtp.Matrix, jtp.Matrix]: MA, F = MA_F - Fa_λi = F[:, ν[i]] + U[i] @ M_inv[i, ν[i]] + Fa_λi = F[λ[i], :, ν[i]] + i_X_λi[i].T @ Fa F = F.at[:, ν[i]].set(Fa_λi) MA_λi = MA[λ[i]] + i_X_λi[i].T @ Ma @ i_X_λi[i] @@ -198,13 +199,7 @@ def propagate( # Pass 3 # ====== - P = jnp.zeros( - shape=( - model.number_of_links(), - model.number_of_links(), - model.number_of_links(), - ) - ) + P = jnp.zeros_like(F) Pass3Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix] pass_3_carry = (U, M_inv, P) @@ -213,17 +208,13 @@ def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> tuple[Pass3Carry, None]: U, M_inv, P = carry - mask = jnp.arange(P.shape[1]) >= i # equivalent to [i, i:] + mask = jnp.arange(P.shape[1]) >= i def propagate_M_inv(M_inv: jtp.Matrix) -> jtp.Matrix: - P_ii = jax.lax.dynamic_slice( - P, (i, i - P.shape[1], P.shape[2]), (P.shape[0], 1) * mask + P_λi = jnp.where(mask, i_X_λi[i].T @ P[λ[i], i], P[λ[i], i]) + M_inv = M_inv.at[i].set( + jnp.where(mask, M_inv[i] - U[i].T @ P_λi / d[i], M_inv) ) - M_inv_ii = jax.lax.dynamic_slice( - M_inv.squeeze(), (i, i - M_inv.squeeze().shape[0]), i_X_λi[i].shape - ) - M_inv_ii = M_inv_ii.at[:].set(M_inv_ii - U[i].T @ i_X_λi[i] @ P_ii / d[i]) - jax.lax.dynamic_update_slice(M, M_inv_ii, (i, i)) return M_inv @@ -234,19 +225,14 @@ def propagate_M_inv(M_inv: jtp.Matrix) -> jtp.Matrix: operand=M_inv, ) - M_inv_ii = jax.lax.dynamic_slice( - M_inv, (i, i - d.shape[0], 1), (1, d[i].shape - i, 1) - ) + M_inv_ii = M_inv[i] * mask - P_i = S[i].T @ M_inv_ii - P = P.at[i].set(P_i.squeeze()) + P_ii = S[i].T @ M_inv_ii + P = P.at[i].set(P_ii.squeeze()) def propagate_P(P: jtp.Vector) -> jtp.Vector: - P_λii = jax.lax.dynamic_slice(P, (λ[i], i), (1, i)) - P_iii = jax.lax.dynamic_slice(P, (i, i), (1, i)) - - P_iii = P_iii.at[:].set(P_iii + i_X_λi[i].T @ P_λii) - jax.lax.dynamic_update_slice(P, P_iii, (i, i)) + P_λi = jnp.where(mask, i_X_λi[i].T @ P[λ[i], i], P[λ[i], i]) + P = P.at[i].set(jnp.where(mask, P[i] + P_λi, P[i])) return P From e042c22b62d9b0bdc4f21a6530e5718418fe8a33 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 18 Sep 2024 19:10:53 +0200 Subject: [PATCH 4/8] [WIP] --- src/jaxsim/rbda/mass_inverse.py | 116 +++++++++++++++++++------------- 1 file changed, 70 insertions(+), 46 deletions(-) diff --git a/src/jaxsim/rbda/mass_inverse.py b/src/jaxsim/rbda/mass_inverse.py index 61ee70be1..3dcc760f7 100644 --- a/src/jaxsim/rbda/mass_inverse.py +++ b/src/jaxsim/rbda/mass_inverse.py @@ -78,13 +78,6 @@ def mass_inverse( # Allocate buffers. MA = jnp.zeros(shape=(model.number_of_links(), 6, 6)) - M_inv = jnp.zeros( - shape=( - model.number_of_links() + 6 * model.floating_base(), - model.number_of_links() + 6 * model.floating_base(), - 1, - ) - ) # Allocate the buffer of transforms link -> base. i_X_0 = jnp.zeros(shape=(model.number_of_links(), 6, 6)) @@ -92,7 +85,6 @@ def mass_inverse( # Initialize base quantities. if model.floating_base(): - # Initialize the articulated-body inertia (Mᴬ) of base link. MA_0 = M[0] MA = MA.at[0].set(MA_0) @@ -135,7 +127,25 @@ def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> tuple[Pass1Carry, None]: U = jnp.zeros_like(S) d = jnp.zeros(shape=(model.number_of_links(), 1)) - F = jnp.zeros(shape=(6, model.number_of_links())) + M_inv = jnp.zeros( + shape=( + model.number_of_joints() + 6 * model.floating_base(), + model.number_of_joints() + 6 * model.floating_base(), + 1, + ) + ) + + if model.number_of_joints() == 0: + M_inv_0 = jnp.linalg.solve(MA[0], jnp.eye(6)) + M_inv = M_inv.at[:].set(M_inv_0) + + F = jnp.zeros( + shape=( + model.number_of_links(), + 6, + model.number_of_links(), + ) + ) Pass2Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix] pass_2_carry: Pass2Carry = (U, d, M_inv, MA, F) @@ -152,14 +162,10 @@ def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]: d_i = S[i].T @ U[i] d = d.at[i].set(d_i.squeeze()) - # Compute the articulated-body inertia and bias force of this link. - Ma = MA[i] - U[i] / d[i] @ U[i].T - Fa = F[i, :, ν[i]] + U[i] @ M_inv[i, ν[i]] - - M_inv_ii = 1 / d[i] - M_inv = M_inv.at[i, i].set(M_inv_ii) + M_inv_i = 1 / d[i] + M_inv = M_inv.at[i, i].set(M_inv_i) - M_inv_iν = M_inv[i, ν[i]] - S[i].T @ F[:, ν[i]].squeeze() / d[i].T + M_inv_iν = M_inv[i, ν[i]] - S[i].T @ F[i, :, ν[i]].squeeze() / d[i].T M_inv = M_inv.at[i, ν[i]].set(M_inv_iν) # Propagate them to the parent, handling the base link. @@ -168,10 +174,14 @@ def propagate( ) -> tuple[jtp.Matrix, jtp.Matrix]: MA, F = MA_F - Fa_λi = F[λ[i], :, ν[i]] + i_X_λi[i].T @ Fa - F = F.at[:, ν[i]].set(Fa_λi) + # Compute the articulated-body inertia and bias force of this link. + Ma_i = MA[i] - U[i] / d[i] @ U[i].T + Fa_i = F[i, :, ν[i]] + U[i] @ M_inv[i, ν[i]] - MA_λi = MA[λ[i]] + i_X_λi[i].T @ Ma @ i_X_λi[i] + Fa_λi = F[λ[i], :, ν[i]] + i_X_λi[i].T @ Fa_i + F = F.at[λ[i], :, ν[i]].set(Fa_λi) + + MA_λi = MA[λ[i]] + i_X_λi[i].T @ Ma_i @ i_X_λi[i] MA = MA.at[λ[i]].set(MA_λi) return MA, F @@ -185,21 +195,29 @@ def propagate( return (U, d, M_inv, MA, F), None - (U, d, M_inv, MA, F), _ = ( - jax.lax.scan( - f=loop_body_pass2, - init=pass_2_carry, - xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())), + with jax.disable_jit(True): + (U, d, M_inv, MA, F), _ = ( + jax.lax.scan( + f=loop_body_pass2, + init=pass_2_carry, + xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())), + ) + if model.number_of_links() > 1 + else [(U, d, M_inv, MA, F), None] ) - if model.number_of_links() > 1 - else [(U, d, M_inv, MA, F), None] - ) # ====== # Pass 3 # ====== - P = jnp.zeros_like(F) + P = jnp.zeros( + shape=( + model.number_of_joints(), + model.number_of_joints(), + 6, + model.number_of_joints() + 6 * model.floating_base(), + ) + ) Pass3Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix] pass_3_carry = (U, M_inv, P) @@ -208,12 +226,17 @@ def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> tuple[Pass3Carry, None]: U, M_inv, P = carry - mask = jnp.arange(P.shape[1]) >= i + mask = jnp.arange(P.shape[-1]) >= i + mask_M = jnp.atleast_2d(mask).T def propagate_M_inv(M_inv: jtp.Matrix) -> jtp.Matrix: - P_λi = jnp.where(mask, i_X_λi[i].T @ P[λ[i], i], P[λ[i], i]) + M_inv = M_inv.at[i].set( - jnp.where(mask, M_inv[i] - U[i].T @ P_λi / d[i], M_inv) + jnp.where( + mask_M, + M_inv[i] - (U[i].T @ i_X_λi[i].T @ P[λ[i], i]).T / d[i], + M_inv[i], + ) ) return M_inv @@ -225,14 +248,14 @@ def propagate_M_inv(M_inv: jtp.Matrix) -> jtp.Matrix: operand=M_inv, ) - M_inv_ii = M_inv[i] * mask - - P_ii = S[i].T @ M_inv_ii - P = P.at[i].set(P_ii.squeeze()) + P_ii = jnp.where(mask, S[i] @ M_inv[i].T, P[i, i]) + P = P.at[i].set(P_ii) def propagate_P(P: jtp.Vector) -> jtp.Vector: - P_λi = jnp.where(mask, i_X_λi[i].T @ P[λ[i], i], P[λ[i], i]) - P = P.at[i].set(jnp.where(mask, P[i] + P_λi, P[i])) + + P = P.at[i, i].set( + jnp.where(mask, P[i, i] + i_X_λi[i].T @ P[λ[i], i], P[i, i]) + ) return P @@ -245,18 +268,19 @@ def propagate_P(P: jtp.Vector) -> jtp.Vector: return (U, M_inv, P), None - (U, M_inv, P), _ = ( - jax.lax.scan( - f=loop_body_pass3, - init=pass_3_carry, - xs=jnp.arange(1, model.number_of_links()), + with jax.disable_jit(True): + (U, M_inv, P), _ = ( + jax.lax.scan( + f=loop_body_pass3, + init=pass_3_carry, + xs=jnp.arange(1, model.number_of_links()), + ) + if model.number_of_links() > 1 + else [(U, M_inv, P), None] ) - if model.number_of_links() > 1 - else [(U, M_inv, P), None] - ) # ============== # Adjust outputs # ============== - return M_inv + return M_inv.squeeze() From c29da9eec6d7ab7ee985261f26e6929f880190ab Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 18 Sep 2024 19:47:54 +0200 Subject: [PATCH 5/8] [WIP] Reduced dimensions --- src/jaxsim/rbda/mass_inverse.py | 90 ++++++++++++++------------------- 1 file changed, 39 insertions(+), 51 deletions(-) diff --git a/src/jaxsim/rbda/mass_inverse.py b/src/jaxsim/rbda/mass_inverse.py index 3dcc760f7..d1399f450 100644 --- a/src/jaxsim/rbda/mass_inverse.py +++ b/src/jaxsim/rbda/mass_inverse.py @@ -131,7 +131,6 @@ def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> tuple[Pass1Carry, None]: shape=( model.number_of_joints() + 6 * model.floating_base(), model.number_of_joints() + 6 * model.floating_base(), - 1, ) ) @@ -141,9 +140,9 @@ def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> tuple[Pass1Carry, None]: F = jnp.zeros( shape=( - model.number_of_links(), + model.number_of_joints(), 6, - model.number_of_links(), + model.number_of_joints() + 6 * model.floating_base(), ) ) @@ -154,7 +153,8 @@ def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]: (U, d, M_inv, MA, F) = carry - ν = jnp.where(νb[i], size=model.number_of_links())[0] + # ν = jnp.where(νb[i], size=model.number_of_links())[0] + ii = i - 1 U_i = MA[i] @ S[i] U = U.at[i].set(U_i) @@ -162,11 +162,10 @@ def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]: d_i = S[i].T @ U[i] d = d.at[i].set(d_i.squeeze()) - M_inv_i = 1 / d[i] - M_inv = M_inv.at[i, i].set(M_inv_i) + M_inv_νν = -S[i].T.squeeze() @ F[i].squeeze() / d[i].T - M_inv_iν = M_inv[i, ν[i]] - S[i].T @ F[i, :, ν[i]].squeeze() / d[i].T - M_inv = M_inv.at[i, ν[i]].set(M_inv_iν) + M_inv = M_inv.at[ii, ii].set(M_inv[i, i] + 1 / d[i].squeeze()) + M_inv = M_inv.at[ii].set(M_inv_νν.T) # Propagate them to the parent, handling the base link. def propagate( @@ -175,14 +174,11 @@ def propagate( MA, F = MA_F # Compute the articulated-body inertia and bias force of this link. - Ma_i = MA[i] - U[i] / d[i] @ U[i].T - Fa_i = F[i, :, ν[i]] + U[i] @ M_inv[i, ν[i]] - - Fa_λi = F[λ[i], :, ν[i]] + i_X_λi[i].T @ Fa_i - F = F.at[λ[i], :, ν[i]].set(Fa_λi) + Ma = MA[i] - U[i].squeeze() / d[i].squeeze() @ U[i].T.squeeze() + Fa = F[i] + U[i] @ jnp.atleast_2d(M_inv[ii]) - MA_λi = MA[λ[i]] + i_X_λi[i].T @ Ma_i @ i_X_λi[i] - MA = MA.at[λ[i]].set(MA_λi) + F = F.at[λ[i]].set(F[λ[i]] + i_X_λi[i].T @ Fa) + MA = MA.at[λ[i]].set(MA[λ[i]] + i_X_λi[i].T @ Ma @ i_X_λi[i]) return MA, F @@ -195,50 +191,34 @@ def propagate( return (U, d, M_inv, MA, F), None - with jax.disable_jit(True): - (U, d, M_inv, MA, F), _ = ( - jax.lax.scan( - f=loop_body_pass2, - init=pass_2_carry, - xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())), - ) - if model.number_of_links() > 1 - else [(U, d, M_inv, MA, F), None] + (U, d, M_inv, MA, F), _ = ( + jax.lax.scan( + f=loop_body_pass2, + init=pass_2_carry, + xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())), ) + if model.number_of_links() > 1 + else [(U, d, M_inv, MA, F), None] + ) # ====== # Pass 3 # ====== - P = jnp.zeros( - shape=( - model.number_of_joints(), - model.number_of_joints(), - 6, - model.number_of_joints() + 6 * model.floating_base(), - ) - ) + P = jnp.zeros_like(F) Pass3Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix] pass_3_carry = (U, M_inv, P) def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> tuple[Pass3Carry, None]: + ii = i - 1 U, M_inv, P = carry mask = jnp.arange(P.shape[-1]) >= i - mask_M = jnp.atleast_2d(mask).T def propagate_M_inv(M_inv: jtp.Matrix) -> jtp.Matrix: - M_inv = M_inv.at[i].set( - jnp.where( - mask_M, - M_inv[i] - (U[i].T @ i_X_λi[i].T @ P[λ[i], i]).T / d[i], - M_inv[i], - ) - ) - return M_inv M_inv = jax.lax.cond( @@ -248,22 +228,30 @@ def propagate_M_inv(M_inv: jtp.Matrix) -> jtp.Matrix: operand=M_inv, ) - P_ii = jnp.where(mask, S[i] @ M_inv[i].T, P[i, i]) - P = P.at[i].set(P_ii) + P_i = jnp.where(mask, S[i] @ jnp.atleast_2d(M_inv[ii]), P[i]) + P = P.at[i].set(P_i) - def propagate_P(P: jtp.Vector) -> jtp.Vector: + def propagate_M_P(M_inv_P: tuple[jtp.Matrix, jtp.Array]) -> jtp.Vector: + M_inv, P = M_inv_P - P = P.at[i, i].set( - jnp.where(mask, P[i, i] + i_X_λi[i].T @ P[λ[i], i], P[i, i]) + M_inv = M_inv.at[ii].set( + jnp.where( + mask, + M_inv[ii] + - (U[i].T.squeeze() @ i_X_λi[i].T @ P[λ[i]]).T / d[i].squeeze(), + M_inv[ii], + ) ) - return P + P = P.at[i].set(jnp.where(mask, P[i] + i_X_λi[i].T @ P[λ[i]], P[i])) + + return M_inv, P - P = jax.lax.cond( + M_inv, P = jax.lax.cond( pred=jnp.logical_or(λ[i] != 0, model.floating_base()), - true_fun=propagate_P, - false_fun=lambda P: P, - operand=P, + true_fun=propagate_M_P, + false_fun=lambda M_inv_P: (M_inv, P), + operand=(M_inv, P), ) return (U, M_inv, P), None From 0d0199fd34aedd09863e6f76b7b94598f3f600b8 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 23 Sep 2024 09:41:47 +0200 Subject: [PATCH 6/8] wip --- src/jaxsim/rbda/mass_inverse.py | 76 ++++++++++----------------------- 1 file changed, 23 insertions(+), 53 deletions(-) diff --git a/src/jaxsim/rbda/mass_inverse.py b/src/jaxsim/rbda/mass_inverse.py index d1399f450..91c95676e 100644 --- a/src/jaxsim/rbda/mass_inverse.py +++ b/src/jaxsim/rbda/mass_inverse.py @@ -47,28 +47,12 @@ def mass_inverse( # Note: λ(0) must not be used, it's initialized to -1. λ = model.kin_dyn_parameters.parent_array - # Build the ν(i) array, containing all the bodies in the subtree supported - # by joint i. Similarly to κ(i), we compute the boolean version νb(i), so that - # it can be stored in a matrix. - νb = jnp.zeros(shape=(λ.size, λ.size), dtype=bool) - - for i in reversed(range(len(λ))): - - νb = νb.at[i, i].set(True) - j = λ[i] - - while j > -1: - νb = νb.at[j, i].set(True) - j = λ[j] - # Compute the base transform. W_H_B = jaxlie.SE3.from_rotation_and_translation( rotation=jaxlie.SO3(wxyz=W_Q_B), translation=W_p_B, ) - νb = jnp.array(νb) - # Compute the parent-to-child adjoints and the motion subspaces of the joints. # These transforms define the relative kinematics of the entire model, including # the base transform for both floating-base and fixed-base models. @@ -151,10 +135,8 @@ def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> tuple[Pass1Carry, None]: def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]: - (U, d, M_inv, MA, F) = carry - - # ν = jnp.where(νb[i], size=model.number_of_links())[0] ii = i - 1 + (U, d, M_inv, MA, F) = carry U_i = MA[i] @ S[i] U = U.at[i].set(U_i) @@ -162,10 +144,10 @@ def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]: d_i = S[i].T @ U[i] d = d.at[i].set(d_i.squeeze()) - M_inv_νν = -S[i].T.squeeze() @ F[i].squeeze() / d[i].T + M_inv_ii = -S[i].T @ F[i] / d[i] + M_inv = M_inv.at[ii].set(M_inv_ii.squeeze()) M_inv = M_inv.at[ii, ii].set(M_inv[i, i] + 1 / d[i].squeeze()) - M_inv = M_inv.at[ii].set(M_inv_νν.T) # Propagate them to the parent, handling the base link. def propagate( @@ -174,8 +156,8 @@ def propagate( MA, F = MA_F # Compute the articulated-body inertia and bias force of this link. - Ma = MA[i] - U[i].squeeze() / d[i].squeeze() @ U[i].T.squeeze() - Fa = F[i] + U[i] @ jnp.atleast_2d(M_inv[ii]) + Fa = F[i] + U[i] @ M_inv[ii][jnp.newaxis, :] + Ma = MA[i] - U[i] / d[i] @ U[i].T F = F.at[λ[i]].set(F[λ[i]] + i_X_λi[i].T @ Fa) MA = MA.at[λ[i]].set(MA[λ[i]] + i_X_λi[i].T @ Ma @ i_X_λi[i]) @@ -191,15 +173,16 @@ def propagate( return (U, d, M_inv, MA, F), None - (U, d, M_inv, MA, F), _ = ( - jax.lax.scan( - f=loop_body_pass2, - init=pass_2_carry, - xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())), + with jax.disable_jit(True): + (U, d, M_inv, MA, F), _ = ( + jax.lax.scan( + f=loop_body_pass2, + init=pass_2_carry, + xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())), + ) + if model.number_of_links() > 1 + else [(U, d, M_inv, MA, F), None] ) - if model.number_of_links() > 1 - else [(U, d, M_inv, MA, F), None] - ) # ====== # Pass 3 @@ -215,35 +198,18 @@ def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> tuple[Pass3Carry, None]: ii = i - 1 U, M_inv, P = carry - mask = jnp.arange(P.shape[-1]) >= i - - def propagate_M_inv(M_inv: jtp.Matrix) -> jtp.Matrix: - - return M_inv - - M_inv = jax.lax.cond( - pred=jnp.logical_or(λ[i] != 0, model.floating_base()), - true_fun=propagate_M_inv, - false_fun=lambda M_inv: M_inv, - operand=M_inv, - ) - - P_i = jnp.where(mask, S[i] @ jnp.atleast_2d(M_inv[ii]), P[i]) + P_i = S[i] @ jnp.atleast_2d(M_inv[ii]) P = P.at[i].set(P_i) def propagate_M_P(M_inv_P: tuple[jtp.Matrix, jtp.Array]) -> jtp.Vector: M_inv, P = M_inv_P M_inv = M_inv.at[ii].set( - jnp.where( - mask, - M_inv[ii] - - (U[i].T.squeeze() @ i_X_λi[i].T @ P[λ[i]]).T / d[i].squeeze(), - M_inv[ii], - ) + M_inv[ii] + - (U[i].T.squeeze() @ i_X_λi[i].T @ P[λ[i]]).T / d[i].squeeze(), ) - P = P.at[i].set(jnp.where(mask, P[i] + i_X_λi[i].T @ P[λ[i]], P[i])) + P = P.at[i].set(i_X_λi[i].T @ P[λ[i]]) return M_inv, P @@ -270,5 +236,9 @@ def propagate_M_P(M_inv_P: tuple[jtp.Matrix, jtp.Array]) -> jtp.Vector: # ============== # Adjust outputs # ============== + M_inv = M_inv.squeeze() + + # Mirror the upper triangle to the lower triangle. + M_inv = jnp.triu(M_inv) + jnp.triu(M_inv, k=1).T - return M_inv.squeeze() + return M_inv From bbd3f682ea5b3feae5618f91cd841fca4a57f6e4 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 23 Sep 2024 17:39:38 +0200 Subject: [PATCH 7/8] wip --- src/jaxsim/rbda/mass_inverse.py | 55 ++++++++++----------------------- 1 file changed, 17 insertions(+), 38 deletions(-) diff --git a/src/jaxsim/rbda/mass_inverse.py b/src/jaxsim/rbda/mass_inverse.py index 91c95676e..1b9e3727f 100644 --- a/src/jaxsim/rbda/mass_inverse.py +++ b/src/jaxsim/rbda/mass_inverse.py @@ -77,33 +77,9 @@ def mass_inverse( # Pass 1 # ====== - Pass1Carry = tuple[jtp.Matrix, jtp.Matrix] - pass_1_carry: Pass1Carry = (MA, i_X_0) - # Propagate kinematics and initialize AB inertia and AB bias forces. - def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> tuple[Pass1Carry, None]: - - MA, i_X_0 = carry - - # Initialize the articulated-body inertia. - MA_i = jnp.array(M[i]) - MA = MA.at[i].set(MA_i) - - # Compute the link-to-base transform. - i_Xi_0 = i_X_λi[i] @ i_X_0[λ[i]] - i_X_0 = i_X_0.at[i].set(i_Xi_0) - - return (MA, i_X_0), None - - (MA, i_X_0), _ = ( - jax.lax.scan( - f=loop_body_pass1, - init=pass_1_carry, - xs=jnp.arange(start=1, stop=model.number_of_links()), - ) - if model.number_of_links() > 1 - else [(MA, i_X_0), None] - ) + MA = jnp.array(M) + i_X_0 = jax.vmap(lambda i_X_λi, i_X_0: i_X_λi @ i_X_0)(i_X_λi[1:], i_X_0[1:]) # ====== # Pass 2 @@ -144,23 +120,26 @@ def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]: d_i = S[i].T @ U[i] d = d.at[i].set(d_i.squeeze()) - M_inv_ii = -S[i].T @ F[i] / d[i] + M_inv_i = -S[i].T @ F[ii] / d[i] + M_inv = M_inv.at[i].set(M_inv_i.squeeze()) - M_inv = M_inv.at[ii].set(M_inv_ii.squeeze()) M_inv = M_inv.at[ii, ii].set(M_inv[i, i] + 1 / d[i].squeeze()) + # Compute the articulated-body inertia and bias force of this link. + Ma = MA[i] - U[i] / d[i] @ U[i].T + Fa = F[i] + U[i] * M_inv[ii, ii] + # Propagate them to the parent, handling the base link. def propagate( MA_F: tuple[jtp.Matrix, jtp.Matrix] ) -> tuple[jtp.Matrix, jtp.Matrix]: MA, F = MA_F - # Compute the articulated-body inertia and bias force of this link. - Fa = F[i] + U[i] @ M_inv[ii][jnp.newaxis, :] - Ma = MA[i] - U[i] / d[i] @ U[i].T + Ma_λi = MA[λ[i]] + i_X_λi[i].T @ Ma @ i_X_λi[i] + MA = MA.at[λ[i]].set(Ma_λi) - F = F.at[λ[i]].set(F[λ[i]] + i_X_λi[i].T @ Fa) - MA = MA.at[λ[i]].set(MA[λ[i]] + i_X_λi[i].T @ Ma @ i_X_λi[i]) + Fa_λi = F[λ[i]] + i_X_λi[i].T @ Fa + F = F.at[λ[i]].set(Fa_λi) return MA, F @@ -198,16 +177,16 @@ def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> tuple[Pass3Carry, None]: ii = i - 1 U, M_inv, P = carry - P_i = S[i] @ jnp.atleast_2d(M_inv[ii]) + Ma_inv = i_X_λi[i].T @ P[λ[i]] + + P_i = S[i] @ M_inv[ii, jnp.newaxis] P = P.at[i].set(P_i) def propagate_M_P(M_inv_P: tuple[jtp.Matrix, jtp.Array]) -> jtp.Vector: M_inv, P = M_inv_P - M_inv = M_inv.at[ii].set( - M_inv[ii] - - (U[i].T.squeeze() @ i_X_λi[i].T @ P[λ[i]]).T / d[i].squeeze(), - ) + M_inv_i = M_inv[ii] - U[i].T @ Ma_inv / d[i] + M_inv = M_inv.at[ii].set(M_inv_i.squeeze()) P = P.at[i].set(i_X_λi[i].T @ P[λ[i]]) From c0b6295c5fe94052e19bb2810c615d68be634f5e Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Mon, 23 Sep 2024 17:39:46 +0200 Subject: [PATCH 8/8] [ci skip] [WIP] Use `jax.vmap` in ABA --- src/jaxsim/rbda/aba.py | 59 ++++++++++++++---------------------------- 1 file changed, 20 insertions(+), 39 deletions(-) diff --git a/src/jaxsim/rbda/aba.py b/src/jaxsim/rbda/aba.py index b01f46698..548e43a78 100644 --- a/src/jaxsim/rbda/aba.py +++ b/src/jaxsim/rbda/aba.py @@ -125,9 +125,9 @@ def aba( pass_1_carry: Pass1Carry = (v, c, MA, pA, i_X_0) # Propagate kinematics and initialize AB inertia and AB bias forces. - def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> tuple[Pass1Carry, None]: - - ii = i - 1 + def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> tuple[Pass1Carry]: + ii = i + i = i + 1 v, c, MA, pA, i_X_0 = carry # Project the joint velocity into its motion subspace. @@ -155,16 +155,10 @@ def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> tuple[Pass1Carry, None]: pA_i = Cross.vx_star(v[i]) @ M[i] @ v[i] - i_Xf_W @ jnp.vstack(W_f[i]) pA = pA.at[i].set(pA_i) - return (v, c, MA, pA, i_X_0), None + return (v, c, MA, pA, i_X_0) - (v, c, MA, pA, i_X_0), _ = ( - jax.lax.scan( - f=loop_body_pass1, - init=pass_1_carry, - xs=jnp.arange(start=1, stop=model.number_of_links()), - ) - if model.number_of_links() > 1 - else [(v, c, MA, pA, i_X_0), None] + (v, c, MA, pA, i_X_0) = jax.vmap(loop_body_pass1)( + pass_1_carry, jnp.arange(start=0, stop=model.number_of_links()) ) # ====== @@ -178,9 +172,9 @@ def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> tuple[Pass1Carry, None]: Pass2Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix] pass_2_carry: Pass2Carry = (U, d, u, MA, pA) - def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]: - - ii = i - 1 + def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry]: + ii = i + i = i + 1 U, d, u, MA, pA = carry U_i = MA[i] @ S[i] @@ -198,9 +192,8 @@ def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]: # Propagate them to the parent, handling the base link. def propagate( - MA_pA: tuple[jtp.Matrix, jtp.Matrix] - ) -> tuple[jtp.Matrix, jtp.Matrix]: - + MA_pA: tuple[jtp.Array, jtp.Array] + ) -> tuple[jtp.Array, jtp.Array]: MA, pA = MA_pA MA_λi = MA[λ[i]] + i_X_λi[i].T @ Ma @ i_X_λi[i] @@ -218,16 +211,10 @@ def propagate( operand=(MA, pA), ) - return (U, d, u, MA, pA), None + return (U, d, u, MA, pA) - (U, d, u, MA, pA), _ = ( - jax.lax.scan( - f=loop_body_pass2, - init=pass_2_carry, - xs=jnp.flip(jnp.arange(start=1, stop=model.number_of_links())), - ) - if model.number_of_links() > 1 - else [(U, d, u, MA, pA), None] + (U, d, u, MA, pA) = jax.vmap(loop_body_pass2)( + pass_2_carry, jnp.flip(jnp.arange(start=0, stop=model.number_of_links())) ) # ====== @@ -245,9 +232,9 @@ def propagate( Pass3Carry = tuple[jtp.Matrix, jtp.Vector] pass_3_carry = (a, s̈) - def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> tuple[Pass3Carry, None]: - - ii = i - 1 + def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> tuple[Pass3Carry]: + ii = i + i = i + 1 a, s̈ = carry # Propagate the link acceleration. @@ -261,16 +248,10 @@ def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> tuple[Pass3Carry, None]: a_i = a_i + S[i] * s̈[ii] a = a.at[i].set(a_i) - return (a, s̈), None + return (a, s̈) - (a, s̈), _ = ( - jax.lax.scan( - f=loop_body_pass3, - init=pass_3_carry, - xs=jnp.arange(1, model.number_of_links()), - ) - if model.number_of_links() > 1 - else [(a, s̈), None] + (a, s̈) = jax.vmap(loop_body_pass3)( + pass_3_carry, jnp.arange(1, model.number_of_links()) ) # ==============