Skip to content

Commit

Permalink
update precommit config
Browse files Browse the repository at this point in the history
  • Loading branch information
jjshoots committed Aug 8, 2024
1 parent 298f067 commit e5fe4c2
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@ repos:
language: node
pass_filenames: false
types: [python]
additional_dependencies: ["pyright"]
additional_dependencies: ["pyright@latest"]
args:
- --project=pyproject.toml
5 changes: 4 additions & 1 deletion wingman/replay_buffer/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,10 @@ def push(
"""
self.base_buffer.push(
data=self.unwrap_data(data, bulk),
data=self.unwrap_data(
wrapped_data=data,
bulk=bulk,
),
bulk=bulk,
)

Expand Down
8 changes: 4 additions & 4 deletions wingman/replay_buffer/flat_replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _format_data(
# cast to the right dtype
data = np.asarray(
thing,
dtype=self.mode_dtype, # pyright: ignore[reportGeneralTypeIssues]
dtype=self.mode_dtype, # pyright: ignore[reportArgumentType, reportCallIssue]
)

# dim check
Expand All @@ -118,7 +118,7 @@ def _format_data(
data = torch.asarray(
thing,
device=self.storage_device,
dtype=self.mode_dtype, # pyright: ignore[reportGeneralTypeIssues]
dtype=self.mode_dtype, # pyright: ignore[reportArgumentType]
)
data.requires_grad_(False)

Expand Down Expand Up @@ -182,7 +182,7 @@ def push(
[
self.mode_caller.zeros(
(self.mem_size, *item.shape),
dtype=self.mode_dtype, # pyright: ignore[reportGeneralTypeIssues]
dtype=self.mode_dtype, # pyright: ignore[reportArgumentType, reportCallIssue]
)
for item in array_data
]
Expand All @@ -192,7 +192,7 @@ def push(
[
self.mode_caller.zeros(
(self.mem_size, *item.shape[1:]),
dtype=self.mode_dtype, # pyright: ignore[reportGeneralTypeIssues]
dtype=self.mode_dtype, # pyright: ignore[reportArgumentType, reportCallIssue]
)
for item in array_data
]
Expand Down
58 changes: 49 additions & 9 deletions wingman/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

from typing import Any

import numpy as np

try:
Expand All @@ -17,37 +19,75 @@


def gpuize(
input,
input: np.ndarray | torch.Tensor,
device: str | torch.device = __device,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""gpuize.
Args:
----
input: the array that we want to gpuize
device: a string of the device we want to move the thing to
dtype: the datatype that the returned tensor should be
input (np.ndarray | torch.Tensor): the array that we want to gpuize
device (str | torch.device): a string of the device we want to move the thing to
dtype (torch.dtype): the datatype that the returned tensor should be
"""
if torch.is_tensor(input):
return input.to(device=device, dtype=dtype)
return input.to(device=device, dtype=dtype) # pyright: ignore[reportAttributeAccessIssue]
else:
return torch.tensor(input, device=device, dtype=dtype)


def cpuize(input) -> np.ndarray:
def nested_gpuize(
input: dict[str, Any],
device: str | torch.device = __device,
dtype: torch.dtype = torch.float32,
) -> dict[str, Any]:
"""Gpuize but for nested dictionaries of elements.
Args:
----
input (dict[str, Any]): the array that we want to gpuize
device (str | torch.device): a string of the device we want to move the thing to
dtype (torch.dtype): the datatype that the returned tensor should be
"""
for key, value in input.items():
if isinstance(value, dict):
input[key] = nested_gpuize(value, device=device, dtype=dtype)
else:
input[key] = gpuize(value)
return input


def cpuize(input: np.ndarray | torch.Tensor) -> np.ndarray:
"""cpuize.
Args:
----
input: the array of the thing we want to put on the cpu
input (np.ndarray | torch.Tensor): the array of the thing we want to put on the cpu
"""
if torch.is_tensor(input):
return input.detach().cpu().numpy()
return input.detach().cpu().numpy() # pyright: ignore[reportAttributeAccessIssue]
else:
return input
return input # pyright: ignore[reportReturnType]


def nested_cpuize(input: dict[str, Any]) -> dict[str, Any]:
"""Gpuize but for nested dictionaries of elements.
Args:
----
input (dict[str, Any]): the array of the thing we want to put on the cpu
"""
for key, value in input.items():
if isinstance(value, dict):
input[key] = nested_cpuize(value)
else:
input[key] = cpuize(value)
return input


def shutdown_handler(*_):
Expand Down

0 comments on commit e5fe4c2

Please sign in to comment.