From 5d130e3ae43b470ec72ddc61800600e9214a4efc Mon Sep 17 00:00:00 2001 From: giulero Date: Thu, 27 Jun 2024 12:46:12 +0200 Subject: [PATCH 01/16] Add pytorch batch interface --- src/adam/pytorch/computation_batch.py | 426 ++++++++++++++++++++++++++ 1 file changed, 426 insertions(+) create mode 100644 src/adam/pytorch/computation_batch.py diff --git a/src/adam/pytorch/computation_batch.py b/src/adam/pytorch/computation_batch.py new file mode 100644 index 00000000..1e7753a9 --- /dev/null +++ b/src/adam/pytorch/computation_batch.py @@ -0,0 +1,426 @@ +# Copyright (C) 2024 Istituto Italiano di Tecnologia (IIT). All rights reserved. +# This software may be modified and distributed under the terms of the +# GNU Lesser General Public License v2.1 or any later version. + +import jax.numpy as jnp +import numpy as np + +from adam.core.constants import Representations +from adam.core.rbd_algorithms import RBDAlgorithms +from adam.jax.jax_like import SpatialMath +from adam.model import Model, URDFModelFactory +from jax2torch import jax2torch +import torch +import jax + + +class KinDynComputationsBatch: + """This is a small class that retrieves robot quantities using Jax for Floating Base systems. + These functions are vmapped and jit compiled and passed to jax2torch to convert them to torch functions. + """ + + def __init__( + self, + urdfstring: str, + joints_name_list: list = None, + root_link: str = "root_link", + gravity: np.array = jnp.array([0, 0, -9.80665, 0, 0, 0]), + ) -> None: + """ + Args: + urdfstring (str): path of the urdf + joints_name_list (list): list of the actuated joints + root_link (str, optional): the first link. Defaults to 'root_link'. + """ + math = SpatialMath() + factory = URDFModelFactory(path=urdfstring, math=math) + model = Model.build(factory=factory, joints_name_list=joints_name_list) + self.rbdalgos = RBDAlgorithms(model=model, math=math) + self.NDoF = self.rbdalgos.NDoF + self.g = gravity + self.funcs = {} + + def set_frame_velocity_representation( + self, representation: Representations + ) -> None: + """Sets the representation of the velocity of the frames + + Args: + representation (Representations): The representation of the velocity + """ + self.rbdalgos.set_frame_velocity_representation(representation) + + def mass_matrix( + self, base_transform: torch.Tensor, joint_positions: torch.Tensor + ) -> torch.Tensor: + """Returns the Mass Matrix functions computed the CRBA + + Args: + base_transform (torch.Tensor): The batch of homogenous transforms from base to world frame + joint_positions (torch.Tensor): The batch of joints position + + Returns: + M (torch.Tensor): The batch Mass Matrix + """ + + return self.mass_matrix_fun()(base_transform, joint_positions) + + def mass_matrix_fun(self): + """Returns the Mass Matrix functions computed the CRBA as a pytorch function + + Returns: + M (pytorch function): Mass Matrix + """ + + if self.funcs.get("mass_matrix") is not None: + return self.funcs["mass_matrix"] + print("[INFO] Compiling mass matrix function") + + def fun(base_transform, joint_positions): + [M, _] = self.rbdalgos.crba(base_transform, joint_positions) + return M.array + + vmapped_fun = jax.vmap(fun, in_axes=(0, 0)) + jit_vmapped_fun = jax.jit(vmapped_fun) + self.funcs["mass_matrix"] = jax2torch(jit_vmapped_fun) + return self.funcs["mass_matrix"] + + def centroidal_momentum_matrix( + self, base_transform: torch.Tensor, joint_positions: torch.Tensor + ) -> torch.Tensor: + """Returns the Centroidal Momentum Matrix functions computed the CRBA + + Args: + 
base_transform (torch.Tensor): The homogenous transform from base to world frame + joint_positions (torch.Tensor): The joints position + + Returns: + Jcc (torch.Tensor): Centroidal Momentum matrix + """ + + return self.centroidal_momentum_matrix_fun()(base_transform, joint_positions) + + def centroidal_momentum_matrix_fun(self): + """Returns the Centroidal Momentum Matrix functions computed the CRBA as a pytorch function + + Returns: + Jcc (pytorch function): Centroidal Momentum matrix + """ + + if self.funcs.get("centroidal_momentum_matrix") is not None: + return self.funcs["centroidal_momentum_matrix"] + print("[INFO] Compiling centroidal momentum matrix function") + + def fun(base_transform, joint_positions): + [_, Jcm] = self.rbdalgos.crba(base_transform, joint_positions) + return Jcm.array + + vmapped_fun = jax.vmap(fun, in_axes=(0, 0)) + jit_vmapped_fun = jax.jit(vmapped_fun) + self.funcs["centroidal_momentum_matrix"] = jax2torch(jit_vmapped_fun) + return self.funcs["centroidal_momentum_matrix"] + + def relative_jacobian( + self, frame: str, joint_positions: torch.Tensor + ) -> torch.Tensor: + + return self.relative_jacobian_fun(frame)(joint_positions) + + def relative_jacobian_fun(self, frame: str): + """Returns the Jacobian between the root link and a specified frame frames as a pytorch function + + Args: + frame (str): The tip of the chain + + Returns: + J (pytorch function): The Jacobian between the root and the frame + """ + + if self.funcs.get(f"relative_jacobian_{frame}") is not None: + return self.funcs[f"relative_jacobian_{frame}"] + print(f"[INFO] Compiling relative jacobian function for {frame} frame") + + def fun(joint_positions): + return self.rbdalgos.relative_jacobian(frame, joint_positions).array + + vmapped_fun = jax.vmap(fun) + jit_vmapped_fun = jax.jit(vmapped_fun) + self.funcs[f"relative_jacobian_{frame}"] = jax2torch(jit_vmapped_fun) + return self.funcs[f"relative_jacobian_{frame}"] + + def jacobian_dot( + self, + frame: str, + base_transform: torch.Tensor, + joint_positions: torch.Tensor, + base_velocity: torch.Tensor, + joint_velocities: torch.Tensor, + ) -> torch.Tensor: + """Returns the Jacobian derivative relative to the specified frame + + Args: + frame (str): The frame to which the jacobian will be computed + base_transform (torch.Tensor): The homogenous transform from base to world frame + joint_positions (torch.Tensor): The joints position + base_velocity (torch.Tensor): The base velocity + joint_velocities (torch.Tensor): The joint velocities + + Returns: + Jdot (torch.Tensor): The Jacobian derivative relative to the frame + """ + + return self.jacobian_dot_fun(frame)( + base_transform, joint_positions, base_velocity, joint_velocities + ) + + def jacobian_dot_fun( + self, + frame: str, + ): + """Returns the Jacobian derivative between the root and the specified frame as a pytorch function + + Args: + frame (str): The frame to which the jacobian will be computed + + Returns: + Jdot (pytorch function): The Jacobian derivative between the root and the frame + """ + + if self.funcs.get(f"jacobian_dot_{frame}") is not None: + return self.funcs[f"jacobian_dot_{frame}"] + print(f"[INFO] Compiling jacobian dot function for {frame} frame") + + def fun(base_transform, joint_positions, base_velocity, joint_velocities): + return self.rbdalgos.jacobian_dot( + frame, base_transform, joint_positions, base_velocity, joint_velocities + ).array + + vmapped_fun = jax.vmap(fun, in_axes=(0, 0, 0, 0)) + jit_vmapped_fun = jax.jit(vmapped_fun) + 
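+        # The per-sample `fun` above is vectorized over a leading batch axis by
+        # jax.vmap, compiled by jax.jit, and then wrapped with jax2torch so the
+        # cached callable accepts and returns torch.Tensors while remaining
+        # differentiable through torch.autograd.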
self.funcs[f"jacobian_dot_{frame}"] = jax2torch(jit_vmapped_fun) + return self.funcs[f"jacobian_dot_{frame}"] + + def forward_kinematics( + self, frame: str, base_transform: torch.Tensor, joint_positions: torch.Tensor + ) -> torch.Tensor: + """Computes the forward kinematics between the root and the specified frame + + Args: + frame (str): The frame to which the fk will be computed + + Returns: + H (torch.Tensor): The fk represented as Homogenous transformation matrix + """ + + return self.forward_kinematics_fun(frame)(base_transform, joint_positions) + + def forward_kinematics_fun(self, frame: str): + """Computes the forward kinematics between the root and the specified frame as a pytorch function + + Args: + frame (str): The frame to which the fk will be computed + + Returns: + H (pytorch function): The fk represented as Homogenous transformation matrix + """ + + if self.funcs.get(f"forward_kinematics_{frame}") is not None: + return self.funcs[f"forward_kinematics_{frame}"] + print(f"[INFO] Compiling forward kinematics function for {frame} frame") + + def fun(base_transform, joint_positions): + return self.rbdalgos.forward_kinematics( + frame, base_transform, joint_positions + ).array + + vmapped_fun = jax.vmap(fun) + jit_vmapped_fun = jax.jit(vmapped_fun) + self.funcs[f"forward_kinematics_{frame}"] = jax2torch(jit_vmapped_fun) + return self.funcs[f"forward_kinematics_{frame}"] + + def jacobian( + self, frame: str, base_transform: torch.Tensor, joint_positions: torch.Tensor + ) -> torch.Tensor: + """Returns the Jacobian relative to the specified frame + + Args: + frame (str): The frame to which the jacobian will be computed + base_transform (torch.Tensor): The homogenous transform from base to world frame + joint_positions (torch.Tensor): The joints position + + Returns: + J (torch.Tensor): The Jacobian between the root and the frame + """ + return self.jacobian_fun(frame)(base_transform, joint_positions) + + def jacobian_fun(self, frame: str): + """Returns the Jacobian relative to the specified frame as a pytorch function + + Args: + frame (str): The frame to which the jacobian will be computed + + Returns: + J (pytorch function): The Jacobian relative to the frame + """ + if self.funcs.get(f"jacobian_{frame}") is not None: + return self.funcs[f"jacobian_{frame}"] + print(f"[INFO] Compiling jacobian function for {frame} frame") + + def fun(base_transform, joint_positions): + return self.rbdalgos.jacobian(frame, base_transform, joint_positions).array + + vmapped_fun = jax.vmap(fun) + jit_vmapped_fun = jax.jit(vmapped_fun) + self.funcs[f"jacobian_{frame}"] = jax2torch(jit_vmapped_fun) + return self.funcs[f"jacobian_{frame}"] + + def bias_force( + self, + base_transform: torch.Tensor, + joint_positions: torch.Tensor, + base_velocity: torch.Tensor, + joint_velocities: torch.Tensor, + ) -> jnp.array: + """Returns the bias force of the floating-base dynamics equation, + using a reduced RNEA (no acceleration and external forces) + + Args: + base_transform (torch.Tensor): The homogenous transform from base to world frame + joint_positions (torch.Tensor): The joints position + base_velocity (torch.Tensor): The base velocity + joint_velocities (torch.Tensor): The joints velocity + + Returns: + h (torch.Tensor): the bias force + """ + return self.bias_force_fun()( + base_transform, joint_positions, base_velocity, joint_velocities + ) + + def bias_force_fun(self): + """Returns the bias force of the floating-base dynamics equation as a pytorch function + + Returns: + h (pytorch function): the bias 
force + """ + if self.funcs.get("bias_force") is not None: + return self.funcs["bias_force"] + print("[INFO] Compiling bias force function") + + def fun(base_transform, joint_positions, base_velocity, joint_velocities): + return self.rbdalgos.rnea( + base_transform, joint_positions, base_velocity, joint_velocities, self.g + ).array.squeeze() + + vmapped_fun = jax.vmap(fun, in_axes=(0, 0, 0, 0)) + jit_vmapped_fun = jax.jit(vmapped_fun) + self.funcs["bias_force"] = jax2torch(jit_vmapped_fun) + return self.funcs["bias_force"] + + def coriolis_term( + self, + base_transform: torch.Tensor, + joint_positions: torch.Tensor, + base_velocity: torch.Tensor, + joint_velocities: torch.Tensor, + ) -> torch.Tensor: + """Returns the coriolis term of the floating-base dynamics equation, + using a reduced RNEA (no acceleration and external forces) + + Args: + base_transform (torch.Tensor): The homogenous transform from base to world frame + joint_positions (torch.Tensor): The joints position + base_velocity (torch.Tensor): The base velocity + joint_velocities (torch.Tensor): The joints velocity + + Returns: + C (torch.Tensor): the Coriolis term + """ + return self.coriolis_term_fun()( + base_transform, joint_positions, base_velocity, joint_velocities + ) + + def coriolis_term_fun(self): + """Returns the coriolis term of the floating-base dynamics equation as a pytorch function + + Returns: + C (pytorch function): the Coriolis term + """ + if self.funcs.get("coriolis_term") is not None: + return self.funcs["coriolis_term"] + print("[INFO] Compiling coriolis term function") + + def fun(base_transform, joint_positions, base_velocity, joint_velocities): + return self.rbdalgos.rnea( + base_transform, + joint_positions, + base_velocity.reshape(6, 1), + joint_velocities, + np.zeros(6), + ).array.squeeze() + + vmapped_fun = jax.vmap(fun, in_axes=(0, 0, 0, 0)) + jit_vmapped_fun = jax.jit(vmapped_fun) + self.funcs["coriolis_term"] = jax2torch(jit_vmapped_fun) + return self.funcs["coriolis_term"] + + def gravity_term( + self, base_transform: jnp.array, joint_positions: jnp.array + ) -> jnp.array: + """Returns the gravity term of the floating-base dynamics equation, + using a reduced RNEA (no acceleration and external forces) + + Args: + base_transform (jnp.array): The homogenous transform from base to world frame + joint_positions (jnp.array): The joints position + + Returns: + G (jnp.array): the gravity term + """ + return self.rbdalgos.rnea( + base_transform, + joint_positions, + np.zeros(6).reshape(6, 1), + np.zeros(self.NDoF), + self.g, + ).array.squeeze() + + def CoM_position( + self, base_transform: torch.Tensor, joint_positions: torch.Tensor + ) -> torch.Tensor: + """Returns the CoM positon + + Args: + base_transform (torch.Tensor): The homogenous transform from base to world frame + joint_positions (torch.Tensor): The joints position + + Returns: + CoM (torch.Tensor): The CoM position + """ + return self.CoM_position_fun()(base_transform, joint_positions) + + def CoM_position_fun(self): + """Returns the CoM positon as a pytorch function + + Returns: + CoM (pytorch function): The CoM position + """ + if self.funcs.get("CoM_position") is not None: + return self.funcs["CoM_position"] + print("[INFO] Compiling CoM position function") + + def fun(base_transform, joint_positions): + return self.rbdalgos.CoM_position(base_transform, joint_positions).array + + vmapped_fun = jax.vmap(fun) + jit_vmapped_fun = jax.jit(vmapped_fun) + self.funcs["CoM_position"] = jax2torch(jit_vmapped_fun) + return 
self.funcs["CoM_position"] + + def get_total_mass(self) -> float: + """Returns the total mass of the robot + + Returns: + mass: The total mass + """ + return self.rbdalgos.get_total_mass() From 5018c9666e1a447d8c78b50eb02384890195696e Mon Sep 17 00:00:00 2001 From: giulero Date: Thu, 27 Jun 2024 12:46:39 +0200 Subject: [PATCH 02/16] Add in init --- src/adam/pytorch/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/adam/pytorch/__init__.py b/src/adam/pytorch/__init__.py index 8a3a50b6..7a33bea1 100644 --- a/src/adam/pytorch/__init__.py +++ b/src/adam/pytorch/__init__.py @@ -3,4 +3,5 @@ # GNU Lesser General Public License v2.1 or any later version. from .computations import KinDynComputations +from .computation_batch import KinDynComputationsBatch from .torch_like import TorchLike From 27e08b013793af32116995309e88cf76bc8e616d Mon Sep 17 00:00:00 2001 From: giulero Date: Thu, 27 Jun 2024 12:46:57 +0200 Subject: [PATCH 03/16] Add dependency --- setup.cfg | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/setup.cfg b/setup.cfg index e93a0967..b77e5411 100644 --- a/setup.cfg +++ b/setup.cfg @@ -44,6 +44,9 @@ casadi = casadi pytorch = torch + jax + jaxlib + jax2torch test = jax jaxlib @@ -54,6 +57,8 @@ test = icub-models black gitpython + jax2torch + conversions = idyntree all = @@ -61,6 +66,7 @@ all = jaxlib casadi torch + jax2torch [tool:pytest] addopts = --capture=no --verbose From 003f99f2e055978b3aadc23b728a79c65c75a6a0 Mon Sep 17 00:00:00 2001 From: giulero Date: Thu, 27 Jun 2024 12:47:09 +0200 Subject: [PATCH 04/16] Add tests on batching --- tests/pytorch_batch/test_pytorch_batch.py | 199 ++++++++++++++++++++++ 1 file changed, 199 insertions(+) create mode 100644 tests/pytorch_batch/test_pytorch_batch.py diff --git a/tests/pytorch_batch/test_pytorch_batch.py b/tests/pytorch_batch/test_pytorch_batch.py new file mode 100644 index 00000000..46fd58b2 --- /dev/null +++ b/tests/pytorch_batch/test_pytorch_batch.py @@ -0,0 +1,199 @@ +import logging + +import icub_models +import idyntree.swig as idyntree +import jax.numpy as jnp +import numpy as np +import pytest +from jax import config + +import adam +from adam.geometry import utils +from adam.pytorch import KinDynComputationsBatch +from adam.numpy import KinDynComputations +import torch + +np.random.seed(42) +config.update("jax_enable_x64", True) + +model_path = str(icub_models.get_model_file("iCubGazeboV2_5")) + +joints_name_list = [ + "torso_pitch", + "torso_roll", + "torso_yaw", + "l_shoulder_pitch", + "l_shoulder_roll", + "l_shoulder_yaw", + "l_elbow", + "r_shoulder_pitch", + "r_shoulder_roll", + "r_shoulder_yaw", + "r_elbow", + "l_hip_pitch", + "l_hip_roll", + "l_hip_yaw", + "l_knee", + "l_ankle_pitch", + "l_ankle_roll", + "r_hip_pitch", + "r_hip_roll", + "r_hip_yaw", + "r_knee", + "r_ankle_pitch", + "r_ankle_roll", +] + + +comp = KinDynComputationsBatch(model_path, joints_name_list) +comp.set_frame_velocity_representation(adam.Representations.MIXED_REPRESENTATION) + +comp_np = KinDynComputations(model_path, joints_name_list) +comp_np.set_frame_velocity_representation(adam.Representations.MIXED_REPRESENTATION) + +n_dofs = len(joints_name_list) +# base pose quantities +xyz = (np.random.rand(3) - 0.5) * 5 +rpy = (np.random.rand(3) - 0.5) * 5 +base_vel = (np.random.rand(6) - 0.5) * 5 +# joints quantitites +joints_val = (np.random.rand(n_dofs) - 0.5) * 5 +joints_dot_val = (np.random.rand(n_dofs) - 0.5) * 5 + +g = np.array([0, 0, -9.80665]) +H_b = utils.H_from_Pos_RPY(xyz, rpy) +n_samples = 10 + +H_b_batch = 
torch.tile(torch.tensor(H_b), (n_samples, 1, 1)).requires_grad_() +joints_val_batch = torch.tile(torch.tensor(joints_val), (n_samples, 1)).requires_grad_() +base_vel_batch = torch.tile(torch.tensor(base_vel), (n_samples, 1)).requires_grad_() +joints_dot_val_batch = torch.tile( + torch.tensor(joints_dot_val), (n_samples, 1) +).requires_grad_() + + +# Check if the quantities are the correct testing against the numpy implementation +# Check if the dimensions are correct (batch dimension) +# Check if the gradient is computable + + +def test_mass_matrix(): + mass_matrix = comp.mass_matrix(H_b_batch, joints_val_batch) + mass_matrix_np = comp_np.mass_matrix(H_b, joints_val) + assert np.allclose(mass_matrix[0].detach().numpy(), mass_matrix_np) + assert mass_matrix.shape == (n_samples, n_dofs + 6, n_dofs + 6) + # check if the gradient is computable + mass_matrix.sum().backward() + return True + # assert torch.autograd.gradcheck( + # comp.mass_matrix, (H_b_batch, joints_val_batch), eps=1e-6, atol=1e-4 + # ) + # return True + + +def test_centroidal_momentum_matrix(): + centroidal_momentum_matrix = comp.centroidal_momentum_matrix( + H_b_batch, joints_val_batch + ) + centroidal_momentum_matrix_np = comp_np.centroidal_momentum_matrix(H_b, joints_val) + assert np.allclose( + centroidal_momentum_matrix[0].detach().numpy(), centroidal_momentum_matrix_np + ) + assert centroidal_momentum_matrix.shape == (n_samples, 6, n_dofs + 6) + centroidal_momentum_matrix.sum().backward() + return True + + +def test_relative_jacobian(): + frame = "l_sole" + relative_jacobian = comp.relative_jacobian(frame, joints_val_batch) + assert np.allclose( + relative_jacobian[0].detach().numpy(), + comp_np.relative_jacobian(frame, joints_val), + ) + assert relative_jacobian.shape == (n_samples, 6, n_dofs) + relative_jacobian.sum().backward() + return True + + +def test_jacobian_dot(): + frame = "l_sole" + jacobian_dot = comp.jacobian_dot( + frame, H_b_batch, joints_val_batch, base_vel_batch, joints_dot_val_batch + ) + assert np.allclose( + jacobian_dot[0].detach().numpy(), + comp_np.jacobian_dot(frame, H_b, joints_val, base_vel, joints_dot_val), + ) + assert jacobian_dot.shape == (n_samples, 6, n_dofs + 6) + jacobian_dot.sum().backward() + return True + + +def test_forward_kineamtics(): + frame = "l_sole" + forward_kinematics = comp.forward_kinematics(frame, H_b_batch, joints_val_batch) + assert np.allclose( + forward_kinematics[0].detach().numpy(), + comp_np.forward_kinematics(frame, H_b, joints_val), + ) + assert forward_kinematics.shape == (n_samples, 4, 4) + forward_kinematics.sum().backward() + return True + + +def test_jacobian(): + frame = "l_sole" + jacobian = comp.jacobian(frame, H_b_batch, joints_val_batch) + assert np.allclose( + jacobian[0].detach().numpy(), comp_np.jacobian(frame, H_b, joints_val) + ) + assert jacobian.shape == (n_samples, 6, n_dofs + 6) + jacobian.sum().backward() + return True + + +def test_bias_force(): + bias_force = comp.bias_force( + H_b_batch, joints_val_batch, base_vel_batch, joints_dot_val_batch + ) + assert np.allclose( + bias_force[0].detach().numpy(), + comp_np.bias_force(H_b, joints_val, base_vel, joints_dot_val), + ) + assert bias_force.shape == (n_samples, n_dofs + 6) + bias_force.sum().backward() + return True + + +def test_coriolis_term(): + coriolis_term = comp.coriolis_term( + H_b_batch, joints_val_batch, base_vel_batch, joints_dot_val_batch + ) + assert np.allclose( + coriolis_term[0].detach().numpy(), + comp_np.coriolis_term(H_b, joints_val, base_vel, joints_dot_val), + ) + assert 
coriolis_term.shape == (n_samples, n_dofs + 6) + coriolis_term.sum().backward() + return True + + +def test_gravity_term(): + gravity_term = comp.gravity_term(H_b_batch, joints_val_batch) + assert np.allclose( + gravity_term[0].detach().numpy(), comp_np.gravity_term(H_b, joints_val) + ) + assert gravity_term.shape == (n_samples, n_dofs + 6) + gravity_term.sum().backward() + return True + + +def test_CoM_position(): + CoM_position = comp.CoM_position(H_b_batch, joints_val_batch) + assert np.allclose( + CoM_position[0].detach().numpy(), comp_np.CoM_position(H_b, joints_val) + ) + assert CoM_position.shape == (n_samples, 3) + CoM_position.sum().backward() + return True From 46f0265014bcc535d98ec9782614a090677090ef Mon Sep 17 00:00:00 2001 From: giulero Date: Thu, 27 Jun 2024 14:08:27 +0200 Subject: [PATCH 05/16] Adding forgotten gravity term function --- src/adam/pytorch/computation_batch.py | 40 +++++++++++++++++++-------- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/src/adam/pytorch/computation_batch.py b/src/adam/pytorch/computation_batch.py index 1e7753a9..4939cd7c 100644 --- a/src/adam/pytorch/computation_batch.py +++ b/src/adam/pytorch/computation_batch.py @@ -365,25 +365,43 @@ def fun(base_transform, joint_positions, base_velocity, joint_velocities): return self.funcs["coriolis_term"] def gravity_term( - self, base_transform: jnp.array, joint_positions: jnp.array - ) -> jnp.array: + self, base_transform: torch.Tensor, joint_positions: torch.Tensor + ) -> torch.Tensor: """Returns the gravity term of the floating-base dynamics equation, using a reduced RNEA (no acceleration and external forces) Args: - base_transform (jnp.array): The homogenous transform from base to world frame - joint_positions (jnp.array): The joints position + base_transform (torch.Tensor): The homogenous transform from base to world frame + joint_positions (torch.Tensor): The joints position Returns: G (jnp.array): the gravity term """ - return self.rbdalgos.rnea( - base_transform, - joint_positions, - np.zeros(6).reshape(6, 1), - np.zeros(self.NDoF), - self.g, - ).array.squeeze() + return self.gravity_term_fun()(base_transform, joint_positions) + + def gravity_term_fun(self): + """Returns the gravity term of the floating-base dynamics equation as a pytorch function + + Returns: + G (pytorch function): the gravity term + """ + if self.funcs.get("gravity_term") is not None: + return self.funcs["gravity_term"] + print("[INFO] Compiling gravity term function") + + def fun(base_transform, joint_positions): + return self.rbdalgos.rnea( + base_transform, + joint_positions, + np.zeros(6).reshape(6, 1), + np.zeros(self.NDoF), + self.g, + ).array.squeeze() + + vmapped_fun = jax.vmap(fun) + jit_vmapped_fun = jax.jit(vmapped_fun) + self.funcs["gravity_term"] = jax2torch(jit_vmapped_fun) + return self.funcs["gravity_term"] def CoM_position( self, base_transform: torch.Tensor, joint_positions: torch.Tensor From af1e266a7c81f418f0068a4e5811a8a5c94bf5d9 Mon Sep 17 00:00:00 2001 From: giulero Date: Thu, 27 Jun 2024 14:17:43 +0200 Subject: [PATCH 06/16] Remove comment --- tests/pytorch_batch/test_pytorch_batch.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/pytorch_batch/test_pytorch_batch.py b/tests/pytorch_batch/test_pytorch_batch.py index 46fd58b2..ec1287dc 100644 --- a/tests/pytorch_batch/test_pytorch_batch.py +++ b/tests/pytorch_batch/test_pytorch_batch.py @@ -82,13 +82,8 @@ def test_mass_matrix(): mass_matrix_np = comp_np.mass_matrix(H_b, joints_val) assert 
np.allclose(mass_matrix[0].detach().numpy(), mass_matrix_np) assert mass_matrix.shape == (n_samples, n_dofs + 6, n_dofs + 6) - # check if the gradient is computable mass_matrix.sum().backward() return True - # assert torch.autograd.gradcheck( - # comp.mass_matrix, (H_b_batch, joints_val_batch), eps=1e-6, atol=1e-4 - # ) - # return True def test_centroidal_momentum_matrix(): From 27d33a9b3515ba12dab12a2342947b75d45e51a4 Mon Sep 17 00:00:00 2001 From: giulero Date: Thu, 27 Jun 2024 14:33:44 +0200 Subject: [PATCH 07/16] Remove return True --- tests/pytorch_batch/test_pytorch_batch.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/tests/pytorch_batch/test_pytorch_batch.py b/tests/pytorch_batch/test_pytorch_batch.py index ec1287dc..44b3a33d 100644 --- a/tests/pytorch_batch/test_pytorch_batch.py +++ b/tests/pytorch_batch/test_pytorch_batch.py @@ -83,7 +83,6 @@ def test_mass_matrix(): assert np.allclose(mass_matrix[0].detach().numpy(), mass_matrix_np) assert mass_matrix.shape == (n_samples, n_dofs + 6, n_dofs + 6) mass_matrix.sum().backward() - return True def test_centroidal_momentum_matrix(): @@ -96,7 +95,6 @@ def test_centroidal_momentum_matrix(): ) assert centroidal_momentum_matrix.shape == (n_samples, 6, n_dofs + 6) centroidal_momentum_matrix.sum().backward() - return True def test_relative_jacobian(): @@ -108,7 +106,6 @@ def test_relative_jacobian(): ) assert relative_jacobian.shape == (n_samples, 6, n_dofs) relative_jacobian.sum().backward() - return True def test_jacobian_dot(): @@ -122,7 +119,6 @@ def test_jacobian_dot(): ) assert jacobian_dot.shape == (n_samples, 6, n_dofs + 6) jacobian_dot.sum().backward() - return True def test_forward_kineamtics(): @@ -134,7 +130,6 @@ def test_forward_kineamtics(): ) assert forward_kinematics.shape == (n_samples, 4, 4) forward_kinematics.sum().backward() - return True def test_jacobian(): @@ -145,7 +140,6 @@ def test_jacobian(): ) assert jacobian.shape == (n_samples, 6, n_dofs + 6) jacobian.sum().backward() - return True def test_bias_force(): @@ -158,7 +152,6 @@ def test_bias_force(): ) assert bias_force.shape == (n_samples, n_dofs + 6) bias_force.sum().backward() - return True def test_coriolis_term(): @@ -171,7 +164,6 @@ def test_coriolis_term(): ) assert coriolis_term.shape == (n_samples, n_dofs + 6) coriolis_term.sum().backward() - return True def test_gravity_term(): @@ -181,7 +173,6 @@ def test_gravity_term(): ) assert gravity_term.shape == (n_samples, n_dofs + 6) gravity_term.sum().backward() - return True def test_CoM_position(): @@ -191,4 +182,3 @@ def test_CoM_position(): ) assert CoM_position.shape == (n_samples, 3) CoM_position.sum().backward() - return True From 9c965bc716ddef55de739c202b0f66f3d363ad50 Mon Sep 17 00:00:00 2001 From: giulero Date: Thu, 27 Jun 2024 15:00:38 +0200 Subject: [PATCH 08/16] Update ci_env with jax2torch dep. 
Waiting for conda recipe --- ci_env.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ci_env.yml b/ci_env.yml index 1a858b48..e97eff79 100644 --- a/ci_env.yml +++ b/ci_env.yml @@ -16,6 +16,7 @@ dependencies: - pytest-repeat - icub-models - idyntree >=11.0.0 - - gitpython + - gitpython - jax - pytorch + - jax2torch From 235f054f0c2368fbd2922f2148fe0f8d8176167e Mon Sep 17 00:00:00 2001 From: giulero Date: Thu, 27 Jun 2024 17:03:02 +0200 Subject: [PATCH 09/16] Add pytorch batched usage in readme --- README.md | 60 ++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 48 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index a92a5c39..0de706c0 100644 --- a/README.md +++ b/README.md @@ -17,18 +17,20 @@ **adam** is based on Roy Featherstone's Rigid Body Dynamics Algorithms. ### Table of contents - - [🐍 Dependencies](#-dependencies) - - [πŸ’Ύ Installation](#-installation) - - [🐍 Installation with pip](#-installation-with-pip) - - [πŸ“¦ Installation with conda](#-installation-with-conda) - - [Installation from conda-forge package](#installation-from-conda-forge-package) - - [πŸ”¨ Installation from repo](#-installation-from-repo) - - [πŸš€ Usage](#-usage) - - [Jax interface](#jax-interface) - - [CasADi interface](#casadi-interface) - - [PyTorch interface](#pytorch-interface) - - [πŸ¦Έβ€β™‚οΈ Contributing](#️-contributing) - - [Todo](#todo) + +- [🐍 Dependencies](#-dependencies) +- [πŸ’Ύ Installation](#-installation) + - [🐍 Installation with pip](#-installation-with-pip) + - [πŸ“¦ Installation with conda](#-installation-with-conda) + - [Installation from conda-forge package](#installation-from-conda-forge-package) + - [πŸ”¨ Installation from repo](#-installation-from-repo) +- [πŸš€ Usage](#-usage) + - [Jax interface](#jax-interface) + - [CasADi interface](#casadi-interface) + - [PyTorch interface](#pytorch-interface) + - [PyTorch Batched interface](#pytorch-batched-interface) +- [πŸ¦Έβ€β™‚οΈ Contributing](#️-contributing) +- [Todo](#todo) ## 🐍 Dependencies @@ -284,6 +286,40 @@ M = kinDyn.mass_matrix(w_H_b, joints) print(M) ``` +### PyTorch Batched interface + +```python +import adam +from adam.pytorch import KinDynComputationsBatch +import icub_models + +# if you want to icub-models +model_path = icub_models.get_model_file("iCubGazeboV2_5") +# The joint list +joints_name_list = [ + 'torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch', + 'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch', + 'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch', 'l_hip_roll', + 'l_hip_yaw', 'l_knee', 'l_ankle_pitch', 'l_ankle_roll', 'r_hip_pitch', + 'r_hip_roll', 'r_hip_yaw', 'r_knee', 'r_ankle_pitch', 'r_ankle_roll' +] + +kinDyn = KinDynComputationsBatch(model_path, joints_name_list) +# choose the representation you want to use the body fixed representation +kinDyn.set_frame_velocity_representation(adam.Representations.BODY_FIXED_REPRESENTATION) +# or, if you want to use the mixed representation (that is the default) +kinDyn.set_frame_velocity_representation(adam.Representations.MIXED_REPRESENTATION) +w_H_b = np.eye(4) +joints = np.ones(len(joints_name_list)) + +num_samples = 1024 +w_H_b_batch = torch.tensor(np.tile(w_H_b, (num_samples, 1, 1)), dtype=torch.float32) +joints_batch = torch.tensor(np.tile(joints, (num_samples, 1)), dtype=torch.float32) + +M = kinDyn.mass_matrix(w_H_b_batch, joints_batch) +w_H_f = kinDyn.forward_kinematics('frame_name', w_H_b_batch, joints_batch) +``` + ## πŸ¦Έβ€β™‚οΈ Contributing **adam** is an 
open-source project. Contributions are very welcome! From e5f3ca3ce9508d4eb7cea88ab40b72dba652984c Mon Sep 17 00:00:00 2001 From: giulero Date: Thu, 27 Jun 2024 17:10:47 +0200 Subject: [PATCH 10/16] Add additional notes --- README.md | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 0de706c0..4ec1916c 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,7 @@ Other requisites are: - `casadi` - `pytorch` - `numpy` +- `jax2torch` They will be installed in the installation step! @@ -116,6 +117,9 @@ mamba create -n adamenv -c conda-forge adam-robotics If you want to use `jax` or `pytorch`, just install the corresponding package as well. +> [!NOTE] +> Check also the conda JAX installation guide [here](https://jax.readthedocs.io/en/latest/installation.html#conda-community-supported) + ### πŸ”¨ Installation from repo Install in a conda environment the required dependencies: @@ -135,13 +139,13 @@ Install in a conda environment the required dependencies: - **PyTorch** interface dependencies: ```bash - mamba create -n adamenv -c conda-forge pytorch numpy lxml prettytable matplotlib urdfdom-py + mamba create -n adamenv -c conda-forge pytorch numpy lxml prettytable matplotlib urdfdom-py jax2torch ``` - **ALL** interfaces dependencies: ```bash - mamba create -n adamenv -c conda-forge jax casadi pytorch numpy lxml prettytable matplotlib urdfdom-py + mamba create -n adamenv -c conda-forge jax casadi pytorch numpy lxml prettytable matplotlib urdfdom-py jax2torch ``` Activate the environment, clone the repo and install the library: @@ -156,10 +160,13 @@ pip install --no-deps . ## πŸš€ Usage The following are small snippets of the use of **adam**. More examples are arriving! -Have also a look at te `tests` folder. +Have also a look at the `tests` folder. ### Jax interface +> [!NOTE] +> Check also the Jax installation guide [here](https://jax.readthedocs.io/en/latest/installation.html#) + ```python import adam from adam.jax import KinDynComputations From 9e83e1da037aca18c0e5884f39c87a1f9f559dfa Mon Sep 17 00:00:00 2001 From: giulero Date: Thu, 27 Jun 2024 17:26:37 +0200 Subject: [PATCH 11/16] Other notes --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 4ec1916c..81af4253 100644 --- a/README.md +++ b/README.md @@ -214,8 +214,8 @@ jitted_vmapped_frame_fk = jit(vmapped_frame_fk) # and called on a batch of data joints_batch = jnp.tile(joints, (1024, 1)) w_H_b_batch = jnp.tile(w_H_b, (1024, 1, 1)) - w_H_f_batch = jitted_vmapped_frame_fk(w_H_b_batch, joints_batch) +# Note that the first call of the jitted function can be slow, since JAX needs to compile the function. Than it will be fast. ``` @@ -325,6 +325,7 @@ joints_batch = torch.tensor(np.tile(joints, (num_samples, 1)), dtype=torch.float M = kinDyn.mass_matrix(w_H_b_batch, joints_batch) w_H_f = kinDyn.forward_kinematics('frame_name', w_H_b_batch, joints_batch) +# Note that the first call of the jitted function can be slow, since JAX needs to compile the function. Than it will be fast. 
``` ## πŸ¦Έβ€β™‚οΈ Contributing From b20b15dd33c10490c0295cfe6420affe687566ba Mon Sep 17 00:00:00 2001 From: giulero Date: Thu, 27 Jun 2024 17:30:13 +0200 Subject: [PATCH 12/16] Move notes into boxes --- README.md | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 81af4253..81be897b 100644 --- a/README.md +++ b/README.md @@ -215,10 +215,15 @@ jitted_vmapped_frame_fk = jit(vmapped_frame_fk) joints_batch = jnp.tile(joints, (1024, 1)) w_H_b_batch = jnp.tile(w_H_b, (1024, 1, 1)) w_H_f_batch = jitted_vmapped_frame_fk(w_H_b_batch, joints_batch) -# Note that the first call of the jitted function can be slow, since JAX needs to compile the function. Than it will be fast. + ``` +> [!NOTE] +> The first call of the jitted function can be slow, since JAX needs to compile the function. Than it will be fast! + +```python + ### CasADi interface ```python @@ -260,7 +265,6 @@ joints = cs.MX.sym('joints', len(joints_name_list)) M = kinDyn.mass_matrix_fun() print(M(w_H_b, joints)) - ``` ### PyTorch interface @@ -295,6 +299,9 @@ print(M) ### PyTorch Batched interface +> [!NOTE] +> When using this interface, note that the first call of the jitted function can be slow, since JAX needs to compile the function. Than it will be fast! + ```python import adam from adam.pytorch import KinDynComputationsBatch @@ -325,7 +332,6 @@ joints_batch = torch.tensor(np.tile(joints, (num_samples, 1)), dtype=torch.float M = kinDyn.mass_matrix(w_H_b_batch, joints_batch) w_H_f = kinDyn.forward_kinematics('frame_name', w_H_b_batch, joints_batch) -# Note that the first call of the jitted function can be slow, since JAX needs to compile the function. Than it will be fast. ``` ## πŸ¦Έβ€β™‚οΈ Contributing From 2b0c32f1248545d31e59413a641af494eac7ebd0 Mon Sep 17 00:00:00 2001 From: giulero Date: Thu, 27 Jun 2024 17:35:22 +0200 Subject: [PATCH 13/16] Fix typo --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index 81be897b..548cc642 100644 --- a/README.md +++ b/README.md @@ -222,8 +222,6 @@ w_H_f_batch = jitted_vmapped_frame_fk(w_H_b_batch, joints_batch) > [!NOTE] > The first call of the jitted function can be slow, since JAX needs to compile the function. Than it will be fast! -```python - ### CasADi interface ```python From 4fc6e05a23427d004d2f5f42477ed69874600d81 Mon Sep 17 00:00:00 2001 From: giulero Date: Thu, 27 Jun 2024 18:17:01 +0200 Subject: [PATCH 14/16] Fix typo --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 548cc642..83b6486d 100644 --- a/README.md +++ b/README.md @@ -220,7 +220,7 @@ w_H_f_batch = jitted_vmapped_frame_fk(w_H_b_batch, joints_batch) ``` > [!NOTE] -> The first call of the jitted function can be slow, since JAX needs to compile the function. Than it will be fast! +> The first call of the jitted function can be slow, since JAX needs to compile the function. Then it will be faster! ### CasADi interface @@ -298,7 +298,7 @@ print(M) ### PyTorch Batched interface > [!NOTE] -> When using this interface, note that the first call of the jitted function can be slow, since JAX needs to compile the function. Than it will be fast! +> When using this interface, note that the first call of the jitted function can be slow, since JAX needs to compile the function. Then it will be faster! 
```python import adam From 18f07e6f3802d608071da484b1305f7fa9b2a4e5 Mon Sep 17 00:00:00 2001 From: giulero Date: Thu, 27 Jun 2024 18:19:12 +0200 Subject: [PATCH 15/16] Sort imports --- src/adam/pytorch/computation_batch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/adam/pytorch/computation_batch.py b/src/adam/pytorch/computation_batch.py index 4939cd7c..aef57056 100644 --- a/src/adam/pytorch/computation_batch.py +++ b/src/adam/pytorch/computation_batch.py @@ -2,16 +2,16 @@ # This software may be modified and distributed under the terms of the # GNU Lesser General Public License v2.1 or any later version. +import jax import jax.numpy as jnp import numpy as np +import torch +from jax2torch import jax2torch from adam.core.constants import Representations from adam.core.rbd_algorithms import RBDAlgorithms from adam.jax.jax_like import SpatialMath from adam.model import Model, URDFModelFactory -from jax2torch import jax2torch -import torch -import jax class KinDynComputationsBatch: From 9d373165557392c77cbaf79a5599409ab37ebe05 Mon Sep 17 00:00:00 2001 From: giulero Date: Thu, 27 Jun 2024 18:37:55 +0200 Subject: [PATCH 16/16] Align in_axis in functions --- src/adam/pytorch/computation_batch.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/adam/pytorch/computation_batch.py b/src/adam/pytorch/computation_batch.py index aef57056..b5e96a10 100644 --- a/src/adam/pytorch/computation_batch.py +++ b/src/adam/pytorch/computation_batch.py @@ -16,7 +16,7 @@ class KinDynComputationsBatch: """This is a small class that retrieves robot quantities using Jax for Floating Base systems. - These functions are vmapped and jit compiled and passed to jax2torch to convert them to torch functions. + These functions are vmapped and jit compiled and passed to jax2torch to convert them to PyTorch functions. """ def __init__( @@ -233,7 +233,7 @@ def fun(base_transform, joint_positions): frame, base_transform, joint_positions ).array - vmapped_fun = jax.vmap(fun) + vmapped_fun = jax.vmap(fun, in_axes=(0, 0)) jit_vmapped_fun = jax.jit(vmapped_fun) self.funcs[f"forward_kinematics_{frame}"] = jax2torch(jit_vmapped_fun) return self.funcs[f"forward_kinematics_{frame}"] @@ -269,7 +269,7 @@ def jacobian_fun(self, frame: str): def fun(base_transform, joint_positions): return self.rbdalgos.jacobian(frame, base_transform, joint_positions).array - vmapped_fun = jax.vmap(fun) + vmapped_fun = jax.vmap(fun, in_axes=(0, 0)) jit_vmapped_fun = jax.jit(vmapped_fun) self.funcs[f"jacobian_{frame}"] = jax2torch(jit_vmapped_fun) return self.funcs[f"jacobian_{frame}"] @@ -398,7 +398,7 @@ def fun(base_transform, joint_positions): self.g, ).array.squeeze() - vmapped_fun = jax.vmap(fun) + vmapped_fun = jax.vmap(fun, in_axes=(0, 0)) jit_vmapped_fun = jax.jit(vmapped_fun) self.funcs["gravity_term"] = jax2torch(jit_vmapped_fun) return self.funcs["gravity_term"] @@ -430,7 +430,7 @@ def CoM_position_fun(self): def fun(base_transform, joint_positions): return self.rbdalgos.CoM_position(base_transform, joint_positions).array - vmapped_fun = jax.vmap(fun) + vmapped_fun = jax.vmap(fun, in_axes=(0, 0)) jit_vmapped_fun = jax.jit(vmapped_fun) self.funcs["CoM_position"] = jax2torch(jit_vmapped_fun) return self.funcs["CoM_position"]
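A minimal usage sketch of the batched PyTorch interface added by this series, assuming the same iCub model, joint list, and `l_sole` frame used by the new tests, with `jax_enable_x64` enabled so the JAX side works in the same double precision as the torch inputs:

```python
import icub_models
import numpy as np
import torch
from jax import config

import adam
from adam.pytorch import KinDynComputationsBatch

# Match the double-precision torch inputs, as done in the tests.
config.update("jax_enable_x64", True)

model_path = str(icub_models.get_model_file("iCubGazeboV2_5"))
joints_name_list = [
    "torso_pitch", "torso_roll", "torso_yaw",
    "l_shoulder_pitch", "l_shoulder_roll", "l_shoulder_yaw", "l_elbow",
    "r_shoulder_pitch", "r_shoulder_roll", "r_shoulder_yaw", "r_elbow",
    "l_hip_pitch", "l_hip_roll", "l_hip_yaw", "l_knee", "l_ankle_pitch", "l_ankle_roll",
    "r_hip_pitch", "r_hip_roll", "r_hip_yaw", "r_knee", "r_ankle_pitch", "r_ankle_roll",
]

kin_dyn = KinDynComputationsBatch(model_path, joints_name_list)
kin_dyn.set_frame_velocity_representation(adam.Representations.MIXED_REPRESENTATION)

n_samples = 16
H_b_batch = torch.tile(torch.tensor(np.eye(4)), (n_samples, 1, 1)).requires_grad_()
joints_batch = torch.zeros(
    n_samples, len(joints_name_list), dtype=torch.float64, requires_grad=True
)

# The first call of each quantity compiles the vmapped/jitted function; later calls reuse it.
M = kin_dyn.mass_matrix(H_b_batch, joints_batch)  # (n_samples, 6 + n_dofs, 6 + n_dofs)
w_H_f = kin_dyn.forward_kinematics("l_sole", H_b_batch, joints_batch)  # (n_samples, 4, 4)

# Gradients flow back to the torch inputs through jax2torch.
M.sum().backward()
print(joints_batch.grad.shape)  # torch.Size([n_samples, n_dofs])
```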