diff --git a/src/adam/pytorch/computation_batch.py b/src/adam/pytorch/computation_batch.py index aef57056..b5e96a10 100644 --- a/src/adam/pytorch/computation_batch.py +++ b/src/adam/pytorch/computation_batch.py @@ -16,7 +16,7 @@ class KinDynComputationsBatch: """This is a small class that retrieves robot quantities using Jax for Floating Base systems. - These functions are vmapped and jit compiled and passed to jax2torch to convert them to torch functions. + These functions are vmapped and jit compiled and passed to jax2torch to convert them to PyTorch functions. """ def __init__( @@ -233,7 +233,7 @@ def fun(base_transform, joint_positions): frame, base_transform, joint_positions ).array - vmapped_fun = jax.vmap(fun) + vmapped_fun = jax.vmap(fun, in_axes=(0, 0)) jit_vmapped_fun = jax.jit(vmapped_fun) self.funcs[f"forward_kinematics_{frame}"] = jax2torch(jit_vmapped_fun) return self.funcs[f"forward_kinematics_{frame}"] @@ -269,7 +269,7 @@ def jacobian_fun(self, frame: str): def fun(base_transform, joint_positions): return self.rbdalgos.jacobian(frame, base_transform, joint_positions).array - vmapped_fun = jax.vmap(fun) + vmapped_fun = jax.vmap(fun, in_axes=(0, 0)) jit_vmapped_fun = jax.jit(vmapped_fun) self.funcs[f"jacobian_{frame}"] = jax2torch(jit_vmapped_fun) return self.funcs[f"jacobian_{frame}"] @@ -398,7 +398,7 @@ def fun(base_transform, joint_positions): self.g, ).array.squeeze() - vmapped_fun = jax.vmap(fun) + vmapped_fun = jax.vmap(fun, in_axes=(0, 0)) jit_vmapped_fun = jax.jit(vmapped_fun) self.funcs["gravity_term"] = jax2torch(jit_vmapped_fun) return self.funcs["gravity_term"] @@ -430,7 +430,7 @@ def CoM_position_fun(self): def fun(base_transform, joint_positions): return self.rbdalgos.CoM_position(base_transform, joint_positions).array - vmapped_fun = jax.vmap(fun) + vmapped_fun = jax.vmap(fun, in_axes=(0, 0)) jit_vmapped_fun = jax.jit(vmapped_fun) self.funcs["CoM_position"] = jax2torch(jit_vmapped_fun) return self.funcs["CoM_position"]