From 382cf9953e2d1902ebb54662e215d62cc6aa7726 Mon Sep 17 00:00:00 2001 From: Eaad Date: Thu, 26 Sep 2024 16:40:31 +0300 Subject: [PATCH] Replace functools.partial with jax.tree_util.Partial Updated various buffer files to replace functools.partial with jax.tree_util.Partial for consistency and improved functionality. --- flashbax/buffers/mixer.py | 6 +++--- flashbax/buffers/prioritised_trajectory_buffer.py | 12 ++++++------ flashbax/buffers/trajectory_buffer.py | 10 +++++----- flashbax/buffers/trajectory_queue.py | 12 ++++++------ 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/flashbax/buffers/mixer.py b/flashbax/buffers/mixer.py index 95417a8..00f5cf8 100644 --- a/flashbax/buffers/mixer.py +++ b/flashbax/buffers/mixer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import functools +from jax.tree_util import Partial as partial from typing import Callable, Sequence, TypeVar import chex @@ -200,13 +200,13 @@ def make_mixer( # In case of rounding errors, add the remainder to the first buffer's proportion prop_batch_sizes[0] += sample_batch_size - sum(prop_batch_sizes) - mixer_sample_fn = functools.partial( + mixer_sample_fn = partial( sample_mixer_fn, prop_batch_sizes=prop_batch_sizes, sample_fns=sample_fns, ) - mixer_can_sample_fn = functools.partial( + mixer_can_sample_fn = partial( can_sample_mixer_fn, can_sample_fns=can_sample_fns, ) diff --git a/flashbax/buffers/prioritised_trajectory_buffer.py b/flashbax/buffers/prioritised_trajectory_buffer.py index 2337bf8..dd6ddab 100644 --- a/flashbax/buffers/prioritised_trajectory_buffer.py +++ b/flashbax/buffers/prioritised_trajectory_buffer.py @@ -21,7 +21,7 @@ """ -import functools +from jax.tree_util import Partial as partial import warnings from typing import TYPE_CHECKING, Callable, Generic, Optional, Tuple @@ -799,29 +799,29 @@ def make_prioritised_trajectory_buffer( max_length_time_axis = max_size // add_batch_size assert max_length_time_axis is not None - init_fn = functools.partial( + init_fn = partial( prioritised_init, add_batch_size=add_batch_size, max_length_time_axis=max_length_time_axis, period=period, ) - add_fn = functools.partial( + add_fn = partial( prioritised_add, sample_sequence_length=sample_sequence_length, period=period, device=device, ) - sample_fn = functools.partial( + sample_fn = partial( prioritised_sample, batch_size=sample_batch_size, sequence_length=sample_sequence_length, period=period, ) - can_sample_fn = functools.partial( + can_sample_fn = partial( can_sample, min_length_time_axis=min_length_time_axis ) - set_priorities_fn = functools.partial( + set_priorities_fn = partial( set_priorities, priority_exponent=priority_exponent, device=device ) diff --git a/flashbax/buffers/trajectory_buffer.py b/flashbax/buffers/trajectory_buffer.py index 2fba10f..706bd8d 100644 --- a/flashbax/buffers/trajectory_buffer.py +++ b/flashbax/buffers/trajectory_buffer.py @@ -19,7 +19,7 @@ This allows for random sampling of the trajectories within the buffer. """ -import functools +from jax.tree_util import Partial as partial import warnings from typing import TYPE_CHECKING, Callable, Generic, Optional, TypeVar @@ -586,21 +586,21 @@ def make_trajectory_buffer( max_length_time_axis = max_size // add_batch_size assert max_length_time_axis is not None - init_fn = functools.partial( + init_fn = partial( init, add_batch_size=add_batch_size, max_length_time_axis=max_length_time_axis, ) - add_fn = functools.partial( + add_fn = partial( add, ) - sample_fn = functools.partial( + sample_fn = partial( sample, batch_size=sample_batch_size, sequence_length=sample_sequence_length, period=period, ) - can_sample_fn = functools.partial( + can_sample_fn = partial( can_sample, min_length_time_axis=min_length_time_axis ) diff --git a/flashbax/buffers/trajectory_queue.py b/flashbax/buffers/trajectory_queue.py index becf1c4..30aa723 100644 --- a/flashbax/buffers/trajectory_queue.py +++ b/flashbax/buffers/trajectory_queue.py @@ -13,7 +13,7 @@ # limitations under the License. -import functools +from jax.tree_util import Partial as partial import warnings from typing import TYPE_CHECKING, Callable, Generic, Optional, Tuple, TypeVar @@ -386,24 +386,24 @@ def make_trajectory_queue( if max_size is not None: max_length_time_axis = max_size // add_batch_size - init_fn = functools.partial( + init_fn = partial( init, add_batch_size=add_batch_size, max_length_time_axis=max_length_time_axis, ) - add_fn = functools.partial( + add_fn = partial( add, ) - sample_fn = functools.partial( + sample_fn = partial( sample, sequence_length=sample_sequence_length, ) - can_sample_fn = functools.partial( + can_sample_fn = partial( can_sample, sample_sequence_length=sample_sequence_length, max_length_time_axis=max_length_time_axis, ) - can_add_fn = functools.partial( + can_add_fn = partial( can_add, add_sequence_length=add_sequence_length, max_length_time_axis=max_length_time_axis,