Skip to content

Commit

Permalink
Refactor gravity initialization and update post-init conversion to us…
Browse files Browse the repository at this point in the history
…e default tensor type
  • Loading branch information
GiulioRomualdi committed Jan 8, 2025
1 parent c5da859 commit 7ce16ee
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 9 deletions.
4 changes: 1 addition & 3 deletions src/adam/parametric/pytorch/computations_parametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ def __init__(
joints_name_list: list,
links_name_list: list,
root_link: str = None,
gravity: np.array = torch.tensor(
[0, 0, -9.80665, 0, 0, 0], dtype=torch.float64
),
gravity: np.array = torch.tensor([0, 0, -9.80665, 0, 0, 0]),
) -> None:
"""
Args:
Expand Down
4 changes: 1 addition & 3 deletions src/adam/pytorch/computations.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ def __init__(
urdfstring: str,
joints_name_list: list = None,
root_link: str = None,
gravity: np.array = torch.tensor(
[0, 0, -9.80665, 0, 0, 0], dtype=torch.float64
),
gravity: np.array = torch.tensor([0, 0, -9.80665, 0, 0, 0]),
) -> None:
"""
Args:
Expand Down
6 changes: 3 additions & 3 deletions src/adam/pytorch/torch_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ class TorchLike(ArrayLike):
array: torch.Tensor

def __post_init__(self):
"""Converts array to double precision"""
if self.array.dtype != torch.float64:
self.array = self.array.double()
"""Converts array to the default type used in the library"""
if self.array.dtype != torch.get_default_dtype():
self.array = self.array.to(torch.get_default_dtype())

def __setitem__(self, idx, value: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
"""Overrides set item operator"""
Expand Down

0 comments on commit 7ce16ee

Please sign in to comment.