From 7ce16ee3abc102a22a08e810c8d1c98dec936100 Mon Sep 17 00:00:00 2001 From: Giulio Romualdi Date: Wed, 8 Jan 2025 21:08:44 +0100 Subject: [PATCH] Refactor gravity initialization and update post-init conversion to use default tensor type --- src/adam/parametric/pytorch/computations_parametric.py | 4 +--- src/adam/pytorch/computations.py | 4 +--- src/adam/pytorch/torch_like.py | 6 +++--- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/src/adam/parametric/pytorch/computations_parametric.py b/src/adam/parametric/pytorch/computations_parametric.py index daa66d6..f969402 100644 --- a/src/adam/parametric/pytorch/computations_parametric.py +++ b/src/adam/parametric/pytorch/computations_parametric.py @@ -23,9 +23,7 @@ def __init__( joints_name_list: list, links_name_list: list, root_link: str = None, - gravity: np.array = torch.tensor( - [0, 0, -9.80665, 0, 0, 0], dtype=torch.float64 - ), + gravity: np.array = torch.tensor([0, 0, -9.80665, 0, 0, 0]), ) -> None: """ Args: diff --git a/src/adam/pytorch/computations.py b/src/adam/pytorch/computations.py index 3f39d49..6e753fc 100644 --- a/src/adam/pytorch/computations.py +++ b/src/adam/pytorch/computations.py @@ -20,9 +20,7 @@ def __init__( urdfstring: str, joints_name_list: list = None, root_link: str = None, - gravity: np.array = torch.tensor( - [0, 0, -9.80665, 0, 0, 0], dtype=torch.float64 - ), + gravity: np.array = torch.tensor([0, 0, -9.80665, 0, 0, 0]), ) -> None: """ Args: diff --git a/src/adam/pytorch/torch_like.py b/src/adam/pytorch/torch_like.py index de2cb81..342bf0d 100644 --- a/src/adam/pytorch/torch_like.py +++ b/src/adam/pytorch/torch_like.py @@ -17,9 +17,9 @@ class TorchLike(ArrayLike): array: torch.Tensor def __post_init__(self): - """Converts array to double precision""" - if self.array.dtype != torch.float64: - self.array = self.array.double() + """Converts array to the default type used in the library""" + if self.array.dtype != torch.get_default_dtype(): + self.array = self.array.to(torch.get_default_dtype()) def __setitem__(self, idx, value: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike": """Overrides set item operator"""