diff --git a/examples/anakin_dqn_example.ipynb b/examples/anakin_dqn_example.ipynb
index 9c860ca..83b220c 100644
--- a/examples/anakin_dqn_example.ipynb
+++ b/examples/anakin_dqn_example.ipynb
@@ -363,9 +363,9 @@
 "):\n",
 " \"\"\"Broadcasts parameters to device shape.\"\"\"\n",
 " broadcast = lambda x: jnp.broadcast_to(x, (cores_count, num_envs) + x.shape)\n",
- " params = jax.tree_map(broadcast, params) # broadcast to cores and batch.\n",
- " opt_state = jax.tree_map(broadcast, opt_state) # broadcast to cores and batch\n",
- " buffer_state = jax.tree_map(broadcast, buffer_state) # broadcast to cores and batch\n",
+ " params = jax.tree.map(broadcast, params) # broadcast to cores and batch.\n",
+ " opt_state = jax.tree.map(broadcast, opt_state) # broadcast to cores and batch\n",
+ " buffer_state = jax.tree.map(broadcast, buffer_state) # broadcast to cores and batch\n",
 " params_state = Params(\n",
 " online=params,\n",
 " target=params,\n",
@@ -494,7 +494,7 @@
 "def eval(params, rng):\n",
 " \"\"\"Evaluates multiple episodes.\"\"\"\n",
 " rngs = random.split(rng, NUM_EVAL_EPISODES)\n",
- " params = jax.tree_map(lambda x: x[0][0], params)\n",
+ " params = jax.tree.map(lambda x: x[0][0], params)\n",
 " _, tot_r = jax.lax.scan(eval_one_episode, params, rngs)\n",
 " return tot_r.mean()"
 ]
@@ -535,8 +535,8 @@
 "rng, *env_rngs = jax.random.split(rng, cores_count * NUM_ENVS + 1)\n",
 "env_states, env_timesteps = jax.vmap(env.reset)(jnp.stack(env_rngs)) # init envs.\n",
 "reshape = lambda x: x.reshape((cores_count, NUM_ENVS) + x.shape[1:])\n",
- "env_states = jax.tree_map(reshape, env_states) # add dimension to pmap over.\n",
- "env_timesteps = jax.tree_map(reshape, env_timesteps) # add dimension to pmap over.\n",
+ "env_states = jax.tree.map(reshape, env_states) # add dimension to pmap over.\n",
+ "env_timesteps = jax.tree.map(reshape, env_timesteps) # add dimension to pmap over.\n",
 "params_state, opt_state, buffer_state, step_rngs, rng = broadcast_to_device_shape(\n",
 " cores_count, NUM_ENVS, params, opt_state, buffer_state, rng\n",
 ")\n",
@@ -562,7 +562,7 @@
 " # Train\n",
 " start = timeit.default_timer()\n",
 " params_state, buffer_state, opt_state, step_rngs, env_states, env_timesteps = learn(params_state, buffer_state, opt_state, step_rngs, env_states, env_timesteps)\n",
- " params_state = jax.tree_map(lambda x: x.block_until_ready(), params_state) # wait for params to be ready so time is accurate.\n",
+ " params_state = jax.tree.map(lambda x: x.block_until_ready(), params_state) # wait for params to be ready so time is accurate.\n",
 " total_time += timeit.default_timer() - start\n",
 " # Eval\n",
 " rng, eval_rng = jax.random.split(rng, num=2)\n",
diff --git a/examples/anakin_ppo_example.ipynb b/examples/anakin_ppo_example.ipynb
index d6d632d..ea0ed6c 100644
--- a/examples/anakin_ppo_example.ipynb
+++ b/examples/anakin_ppo_example.ipynb
@@ -168,7 +168,7 @@
 " step_rngs = random.split(outer_rng, rollout_len)\n",
 " (env_state, env_timestep, params), rollout = lax.scan(step_fn, (env_state, env_timestep, params), step_rngs) \n",
 "\n",
- " rollout = jax.tree_map(lambda x: jnp.expand_dims(x,0), rollout)\n",
+ " rollout = jax.tree.map(lambda x: jnp.expand_dims(x,0), rollout)\n",
 " \n",
 " return rollout, env_state, env_timestep\n",
 " \n",
@@ -200,7 +200,7 @@
 " buffer_state = buffer_fn.add(buffer_state, data_rollout) \n",
 " buffer_state, batch = buffer_fn.sample(buffer_state) \n",
 " # We get rid of the batch dimension here\n",
- " batch = jax.tree_map(lambda x: jnp.squeeze(x, 0), batch.experience) \n",
+ " batch = jax.tree.map(lambda x: jnp.squeeze(x, 0), batch.experience) \n",
 " \n",
 " def epoch_update(carry, _):\n",
 " \"\"\"Updates the parameters of the agent.\"\"\"\n",
@@ -289,9 +289,9 @@
 "def broadcast_to_device_shape(cores_count, num_envs, params, opt_state, buffer_state, rng):\n",
 " \"\"\"Broadcasts the parameters to the shape of the device.\"\"\"\n",
 " broadcast = lambda x: jnp.broadcast_to(x, (cores_count, num_envs) + x.shape)\n",
- " params = jax.tree_map(broadcast, params) # broadcast to cores and batch.\n",
- " opt_state = jax.tree_map(broadcast, opt_state) # broadcast to cores and batch\n",
- " buffer_state = jax.tree_map(broadcast, buffer_state) # broadcast to cores and batch\n",
+ " params = jax.tree.map(broadcast, params) # broadcast to cores and batch.\n",
+ " opt_state = jax.tree.map(broadcast, opt_state) # broadcast to cores and batch\n",
+ " buffer_state = jax.tree.map(broadcast, buffer_state) # broadcast to cores and batch\n",
 " params_state = Params(online=params, update_count=jnp.zeros(shape=(cores_count, num_envs)))\n",
 " rng, step_rngs = get_rng_keys(cores_count, num_envs, rng)\n",
 " return params_state, opt_state, buffer_state, step_rngs, rng"
@@ -393,7 +393,7 @@
 " \"\"\"Evaluates the agent on multiple episodes.\"\"\"\n",
 "\n",
 " rngs = random.split(rng, NUM_EVAL_EPISODES)\n",
- " params = jax.tree_map(lambda x: x[0][0], params)\n",
+ " params = jax.tree.map(lambda x: x[0][0], params)\n",
 " _, tot_r = jax.lax.scan(eval_one_episode, params, rngs)\n",
 " return tot_r.mean()"
 ]
@@ -431,8 +431,8 @@
 "rng, *env_rngs = jax.random.split(rng, cores_count * NUM_ENVS + 1)\n",
 "env_states, env_timesteps = jax.vmap(env.reset)(jnp.stack(env_rngs)) # init envs.\n",
 "reshape = lambda x: x.reshape((cores_count, NUM_ENVS) + x.shape[1:])\n",
- "env_states = jax.tree_map(reshape, env_states) # add dimension to pmap over.\n",
- "env_timesteps = jax.tree_map(reshape, env_timesteps) # add dimension to pmap over.\n",
+ "env_states = jax.tree.map(reshape, env_states) # add dimension to pmap over.\n",
+ "env_timesteps = jax.tree.map(reshape, env_timesteps) # add dimension to pmap over.\n",
 "params_state, opt_state, buffer_state, step_rngs, rng = broadcast_to_device_shape(cores_count, NUM_ENVS, params, opt_state, buffer_state, rng)\n",
 "\n",
 "\n",
@@ -445,7 +445,7 @@
 " # Train\n",
 " start = timeit.default_timer()\n",
 " params_state, buffer_state, opt_state, step_rngs, env_states, env_timesteps = learn(params_state, buffer_state, opt_state, step_rngs, env_states, env_timesteps)\n",
- " params_state = jax.tree_map(lambda x: x.block_until_ready(), params_state) # wait for params to be ready so time is accurate.\n",
+ " params_state = jax.tree.map(lambda x: x.block_until_ready(), params_state) # wait for params to be ready so time is accurate.\n",
 " total_time += timeit.default_timer() - start\n",
 " # Eval\n",
 " rng, eval_rng = jax.random.split(rng, num=2)\n",
diff --git a/examples/anakin_prioritised_dqn_example.ipynb b/examples/anakin_prioritised_dqn_example.ipynb
index 377a54f..dd13807 100644
--- a/examples/anakin_prioritised_dqn_example.ipynb
+++ b/examples/anakin_prioritised_dqn_example.ipynb
@@ -387,9 +387,9 @@
 "):\n",
 " \"\"\"Broadcasts parameters to device shape.\"\"\"\n",
 " broadcast = lambda x: jnp.broadcast_to(x, (cores_count, num_envs) + x.shape)\n",
- " params = jax.tree_map(broadcast, params) # broadcast to cores and batch.\n",
- " opt_state = jax.tree_map(broadcast, opt_state) # broadcast to cores and batch\n",
- " buffer_state = jax.tree_map(broadcast, buffer_state) # broadcast to cores and batch\n",
+ " params = jax.tree.map(broadcast, params) # broadcast to cores and batch.\n",
+ " opt_state = jax.tree.map(broadcast, opt_state) # broadcast to cores and batch\n",
+ " buffer_state = jax.tree.map(broadcast, buffer_state) # broadcast to cores and batch\n",
 " params_state = Params(\n",
 " online=params,\n",
 " target=params,\n",
@@ -521,7 +521,7 @@
 "def eval(params, rng):\n",
 " \"\"\"Evaluates multiple episodes.\"\"\"\n",
 " rngs = random.split(rng, NUM_EVAL_EPISODES)\n",
- " params = jax.tree_map(lambda x: x[0][0], params)\n",
+ " params = jax.tree.map(lambda x: x[0][0], params)\n",
 " _, tot_r = jax.lax.scan(eval_one_episode, params, rngs)\n",
 " return tot_r.mean()"
 ]
@@ -580,8 +580,8 @@
 "rng, *env_rngs = jax.random.split(rng, cores_count * NUM_ENVS + 1)\n",
 "env_states, env_timesteps = jax.vmap(env.reset)(jnp.stack(env_rngs)) # init envs.\n",
 "reshape = lambda x: x.reshape((cores_count, NUM_ENVS) + x.shape[1:])\n",
- "env_states = jax.tree_map(reshape, env_states) # add dimension to pmap over.\n",
- "env_timesteps = jax.tree_map(reshape, env_timesteps) # add dimension to pmap over.\n",
+ "env_states = jax.tree.map(reshape, env_states) # add dimension to pmap over.\n",
+ "env_timesteps = jax.tree.map(reshape, env_timesteps) # add dimension to pmap over.\n",
 "params_state, opt_state, buffer_state, step_rngs, rng = broadcast_to_device_shape(\n",
 " cores_count, NUM_ENVS, params, opt_state, buffer_state, rng\n",
 ")\n",
@@ -608,7 +608,7 @@
 " # Train\n",
 " start = timeit.default_timer()\n",
 " params_state, buffer_state, opt_state, step_rngs, env_states, env_timesteps = learn(params_state, buffer_state, opt_state, step_rngs, env_states, env_timesteps)\n",
- " params_state = jax.tree_map(lambda x: x.block_until_ready(), params_state) # wait for params to be ready so time is accurate.\n",
+ " params_state = jax.tree.map(lambda x: x.block_until_ready(), params_state) # wait for params to be ready so time is accurate.\n",
 " total_time += timeit.default_timer() - start\n",
 " # Eval\n",
 " rng, eval_rng = jax.random.split(rng, num=2)\n",
diff --git a/examples/quickstart_flat_buffer.ipynb b/examples/quickstart_flat_buffer.ipynb
index b47a768..3d94351 100644
--- a/examples/quickstart_flat_buffer.ipynb
+++ b/examples/quickstart_flat_buffer.ipynb
@@ -168,7 +168,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
- "fake_batch = jax.tree_map(lambda x: jnp.stack([x + i for i in range(add_batch_size)]), fake_timestep)\n",
+ "fake_batch = jax.tree.map(lambda x: jnp.stack([x + i for i in range(add_batch_size)]), fake_timestep)\n",
 "state = buffer.add(state, fake_batch)\n",
 "assert not buffer.can_sample(state) # Buffer is not ready to sample\n",
 "state = buffer.add(state, fake_batch)\n",
@@ -287,7 +287,7 @@
 "source": [
 "# Define a function to create a fake batch of data\n",
 "def get_fake_batch(fake_timestep: chex.ArrayTree, batch_size) -> chex.ArrayTree:\n",
- " return jax.tree_map(lambda x: jnp.stack([x + i for i in range(batch_size)]), fake_timestep)\n",
+ " return jax.tree.map(lambda x: jnp.stack([x + i for i in range(batch_size)]), fake_timestep)\n",
 "\n",
 "add_batch_size = 8\n",
 "\n",
@@ -295,7 +295,7 @@
 "buffer = fbx.make_flat_buffer(max_length, min_length, sample_batch_size, add_sequences, add_batch_size)\n",
 "\n",
 "# Initialize the buffer's state with a \"device\" dimension\n",
- "fake_timestep_per_device = jax.tree_map(\n",
+ "fake_timestep_per_device = jax.tree.map(\n",
 " lambda x: jnp.stack([x + i for i in range(DEVICE_COUNT_MOCK)]), fake_timestep\n",
 ")\n",
 "state = jax.pmap(buffer.init)(fake_timestep_per_device)\n",
diff --git a/examples/quickstart_prioritised_flat_buffer.ipynb b/examples/quickstart_prioritised_flat_buffer.ipynb
index 62f3570..3d15d0c 100644
--- a/examples/quickstart_prioritised_flat_buffer.ipynb
+++ b/examples/quickstart_prioritised_flat_buffer.ipynb
@@ -154,7 +154,7 @@
 "# The add function expects batches of experience - we create a fake batch by stacking\n",
 "# timesteps.\n",
 "# New samples to the buffer have their priority set to the maximum priority within the buffer. \n",
- "fake_batch = jax.tree_map(lambda x: jnp.stack([x + i for i in range(add_batch_size)]),\n",
+ "fake_batch = jax.tree.map(lambda x: jnp.stack([x + i for i in range(add_batch_size)]),\n",
 " fake_timestep) \n",
 "state = buffer.add(state, fake_batch)\n",
 "assert buffer.can_sample(state) == False # After one batch the buffer is not yet full.\n",
diff --git a/examples/quickstart_trajectory_buffer.ipynb b/examples/quickstart_trajectory_buffer.ipynb
index 404054a..d5c024a 100644
--- a/examples/quickstart_trajectory_buffer.ipynb
+++ b/examples/quickstart_trajectory_buffer.ipynb
@@ -210,7 +210,7 @@
 "# The add function expects batches of trajectories.\n",
 "# Thus, we create a fake batch of trajectories by broadcasting the `fake_timestep`.\n",
 "broadcast_fn = lambda x: jnp.broadcast_to(x, (add_batch_size, add_sequence_length, *x.shape))\n",
- "fake_batch_sequence = jax.tree_map(broadcast_fn, fake_timestep)\n",
+ "fake_batch_sequence = jax.tree.map(broadcast_fn, fake_timestep)\n",
 "state = buffer.add(state, fake_batch_sequence)\n",
 "assert buffer.can_sample(state) == False # After one batch the buffer is not yet full.\n",
 "state = buffer.add(state, fake_batch_sequence)\n",
diff --git a/flashbax/buffers/conftest.py b/flashbax/buffers/conftest.py
index fbc8b5d..5423fd5 100644
--- a/flashbax/buffers/conftest.py
+++ b/flashbax/buffers/conftest.py
@@ -49,6 +49,6 @@ def fake_transition() -> chex.ArrayTree:
 def get_fake_batch(fake_transition: chex.ArrayTree, batch_size) -> chex.ArrayTree:
 """Create a fake batch with differing values for each transition."""
- return jax.tree_map(
+ return jax.tree.map(
 lambda x: jnp.stack([x + i for i in range(batch_size)]), fake_transition
 )
diff --git a/flashbax/buffers/flat_buffer_test.py b/flashbax/buffers/flat_buffer_test.py
index a6d4a2d..3692c9b 100644
--- a/flashbax/buffers/flat_buffer_test.py
+++ b/flashbax/buffers/flat_buffer_test.py
@@ -109,7 +109,7 @@ def test_add_batch_size_none(
 max_length: int,
 ):
 # create a fake batch and ensure there is no batch dimension
- fake_batch = jax.tree_map(
+ fake_batch = jax.tree.map(
 lambda x: jnp.squeeze(x, 0), get_fake_batch(fake_transition, 1)
 )
@@ -149,7 +149,7 @@ def test_add_sequences(
 ):
 add_sequence_size = 5
 # create a fake sequence and ensure there is no batch dimension
- fake_batch = jax.tree_map(
+ fake_batch = jax.tree.map(
 lambda x: x.repeat(add_sequence_size, axis=0),
 get_fake_batch(fake_transition, 1),
 )
@@ -184,7 +184,7 @@ def test_add_sequences_and_batches(
 ):
 add_sequence_size = 5
 # create a fake batch and sequence
- fake_batch = jax.tree_map(
+ fake_batch = jax.tree.map(
 lambda x: x[:, jnp.newaxis].repeat(add_sequence_size, axis=1),
 get_fake_batch(fake_transition, add_batch_size),
 )
@@ -228,7 +228,7 @@ def test_flat_replay_buffer_does_not_smoke(
 )
 # Initialise the buffer's state.
- fake_transition_per_device = jax.tree_map(
+ fake_transition_per_device = jax.tree.map(
 lambda x: jnp.stack([x + i for i in range(_DEVICE_COUNT_MOCK)]), fake_transition
 )
 state = jax.pmap(buffer.init)(fake_transition_per_device)
diff --git a/flashbax/buffers/item_buffer.py b/flashbax/buffers/item_buffer.py
index 446b6cc..8f1da9d 100644
--- a/flashbax/buffers/item_buffer.py
+++ b/flashbax/buffers/item_buffer.py
@@ -90,7 +90,7 @@ def add_fn(
 ) -> TrajectoryBufferState[Experience]:
 """Flattens a batch to add items along single time axis."""
 batch_size, seq_len = utils.get_tree_shape_prefix(batch, n_axes=2)
- flattened_batch = jax.tree_map(
+ flattened_batch = jax.tree.map(
 lambda x: x.reshape((1, batch_size * seq_len, *x.shape[2:])), batch
 )
 return buffer.add(state, flattened_batch)
@@ -111,7 +111,7 @@ def sample_fn(
 ) -> TrajectoryBufferSample[Experience]:
 """Samples a batch of items from the buffer."""
 sampled_batch = buffer.sample(state, rng_key).experience
- sampled_batch = jax.tree_map(lambda x: x.squeeze(axis=1), sampled_batch)
+ sampled_batch = jax.tree.map(lambda x: x.squeeze(axis=1), sampled_batch)
 return TrajectoryBufferSample(experience=sampled_batch)
 return buffer.replace(add=add_fn, sample=sample_fn) # type: ignore
diff --git a/flashbax/buffers/item_buffer_test.py b/flashbax/buffers/item_buffer_test.py
index a9633bf..a159086 100644
--- a/flashbax/buffers/item_buffer_test.py
+++ b/flashbax/buffers/item_buffer_test.py
@@ -110,7 +110,7 @@ def test_add_batch_size_none(
 max_length: int,
 ):
 # create a fake batch and ensure there is no batch dimension
- fake_batch = jax.tree_map(
+ fake_batch = jax.tree.map(
 lambda x: jnp.squeeze(x, 0), get_fake_batch(fake_transition, 1)
 )
@@ -148,7 +148,7 @@ def test_add_sequences(
 ):
 add_sequence_size = 5
 # create a fake sequence and ensure there is no batch dimension
- fake_batch = jax.tree_map(
+ fake_batch = jax.tree.map(
 lambda x: x.repeat(add_sequence_size, axis=0),
 get_fake_batch(fake_transition, 1),
 )
@@ -183,7 +183,7 @@ def test_add_sequences_and_batches(
 ):
 add_sequence_size = 5
 # create a fake batch and sequence
- fake_batch = jax.tree_map(
+ fake_batch = jax.tree.map(
 lambda x: x[:, jnp.newaxis].repeat(add_sequence_size, axis=1),
 get_fake_batch(fake_transition, add_batch_size),
 )
@@ -227,7 +227,7 @@ def test_item_replay_buffer_does_not_smoke(
 )
 # Initialise the buffer's state.
- fake_transition_per_device = jax.tree_map(
+ fake_transition_per_device = jax.tree.map(
 lambda x: jnp.stack([x + i for i in range(_DEVICE_COUNT_MOCK)]), fake_transition
 )
 state = jax.pmap(buffer.init)(fake_transition_per_device)
diff --git a/flashbax/buffers/prioritised_flat_buffer_test.py b/flashbax/buffers/prioritised_flat_buffer_test.py
index 649161e..e2ec1de 100644
--- a/flashbax/buffers/prioritised_flat_buffer_test.py
+++ b/flashbax/buffers/prioritised_flat_buffer_test.py
@@ -187,7 +187,7 @@ def test_prioritised_flat_buffer_does_not_smoke(
 )
 # Initialise the buffer's state.
- fake_transition_per_device = jax.tree_map(
+ fake_transition_per_device = jax.tree.map(
 lambda x: jnp.stack([x + i for i in range(_DEVICE_COUNT_MOCK)]), fake_transition
 )
 state = jax.pmap(buffer.init)(fake_transition_per_device)
@@ -230,7 +230,7 @@ def test_add_batch_size_none(
 priority_exponent: float,
 ):
 # create a fake batch and ensure there is no batch dimension
- fake_batch = jax.tree_map(
+ fake_batch = jax.tree.map(
 lambda x: jnp.squeeze(x, 0), get_fake_batch(fake_transition, 1)
 )
@@ -278,7 +278,7 @@ def test_add_sequences(
 ):
 add_sequence_size = 5
 # create a fake sequence and ensure there is no batch dimension
- fake_batch = jax.tree_map(
+ fake_batch = jax.tree.map(
 lambda x: x.repeat(add_sequence_size, axis=0),
 get_fake_batch(fake_transition, 1),
 )
@@ -321,7 +321,7 @@ def test_add_sequences_and_batches(
 ):
 add_sequence_size = 5
 # create a fake batch and sequence
- fake_batch = jax.tree_map(
+ fake_batch = jax.tree.map(
 lambda x: x[:, jnp.newaxis].repeat(add_sequence_size, axis=1),
 get_fake_batch(fake_transition, add_batch_size),
 )
diff --git a/flashbax/buffers/prioritised_item_buffer.py b/flashbax/buffers/prioritised_item_buffer.py
index a78e304..b358f8b 100644
--- a/flashbax/buffers/prioritised_item_buffer.py
+++ b/flashbax/buffers/prioritised_item_buffer.py
@@ -83,7 +83,7 @@ def add_fn(
 ) -> PrioritisedTrajectoryBufferState[Experience]:
 """Flattens a batch to add items along single time axis."""
 batch_size, seq_len = utils.get_tree_shape_prefix(batch, n_axes=2)
- flattened_batch = jax.tree_map(
+ flattened_batch = jax.tree.map(
 lambda x: x.reshape((1, batch_size * seq_len, *x.shape[2:])), batch
 )
 return buffer.add(state, flattened_batch)
@@ -107,7 +107,7 @@ def sample_fn(
 priorities = sampled_batch.priorities
 indices = sampled_batch.indices
 sampled_batch = sampled_batch.experience
- sampled_batch = jax.tree_map(lambda x: x.squeeze(axis=1), sampled_batch)
+ sampled_batch = jax.tree.map(lambda x: x.squeeze(axis=1), sampled_batch)
 return PrioritisedTrajectoryBufferSample(
 experience=sampled_batch, indices=indices, priorities=priorities
 )
diff --git a/flashbax/buffers/prioritised_item_buffer_test.py b/flashbax/buffers/prioritised_item_buffer_test.py
index c7a4207..87be54b 100644
--- a/flashbax/buffers/prioritised_item_buffer_test.py
+++ b/flashbax/buffers/prioritised_item_buffer_test.py
@@ -160,7 +160,7 @@ def test_add_batch_size_none(
 max_length: int,
 ):
 # create a fake batch and ensure there is no batch dimension
- fake_batch = jax.tree_map(
+ fake_batch = jax.tree.map(
 lambda x: jnp.squeeze(x, 0), get_fake_batch(fake_transition, 1)
 )
@@ -201,7 +201,7 @@ def test_add_sequences(
 ):
 add_sequence_size = 5
 # create a fake sequence and ensure there is no batch dimension
- fake_batch = jax.tree_map(
+ fake_batch = jax.tree.map(
 lambda x: x.repeat(add_sequence_size, axis=0),
 get_fake_batch(fake_transition, 1),
 )
@@ -236,7 +236,7 @@ def test_add_sequences_and_batches(
 ):
 add_sequence_size = 5
 # create a fake batch and sequence
- fake_batch = jax.tree_map(
+ fake_batch = jax.tree.map(
 lambda x: x[:, jnp.newaxis].repeat(add_sequence_size, axis=1),
 get_fake_batch(fake_transition, add_batch_size),
 )
@@ -281,7 +281,7 @@ def test_item_replay_buffer_does_not_smoke(
 )
 # Initialise the buffer's state.
- fake_transition_per_device = jax.tree_map(
+ fake_transition_per_device = jax.tree.map(
 lambda x: jnp.stack([x + i for i in range(_DEVICE_COUNT_MOCK)]), fake_transition
 )
 state = jax.pmap(buffer.init)(fake_transition_per_device)
diff --git a/flashbax/buffers/prioritised_trajectory_buffer_test.py b/flashbax/buffers/prioritised_trajectory_buffer_test.py
index 0a194e0..5600247 100644
--- a/flashbax/buffers/prioritised_trajectory_buffer_test.py
+++ b/flashbax/buffers/prioritised_trajectory_buffer_test.py
@@ -204,7 +204,7 @@ def test_prioritised_sample_with_period(
 # Create a batch but specifically ensure that sequences in different add_batch rows
 # are distinct - this is simply for testing purposes in order to verify periodicity
- fake_batch_sequence = jax.tree_map(
+ fake_batch_sequence = jax.tree.map(
 lambda x: jnp.stack([x + i * (max_length - 1) for i in range(add_batch_size)]),
 get_fake_batch(fake_transition, max_length - 1),
 )
@@ -318,7 +318,7 @@ def test_prioritised_trajectory_buffer_does_not_smoke(
 )
 # Initialise the buffer's state.
- fake_trajectory_per_device = jax.tree_map(
+ fake_trajectory_per_device = jax.tree.map(
 lambda x: jnp.stack([x + i for i in range(_DEVICE_COUNT_MOCK)]), fake_transition
 )
 state = jax.pmap(buffer.init)(fake_trajectory_per_device)
diff --git a/flashbax/buffers/trajectory_buffer.py b/flashbax/buffers/trajectory_buffer.py
index 73b33b6..2fba10f 100644
--- a/flashbax/buffers/trajectory_buffer.py
+++ b/flashbax/buffers/trajectory_buffer.py
@@ -88,10 +88,10 @@ def init(
 been added yet.
 """
 # Set experience value to be empty.
- experience = jax.tree_map(jnp.empty_like, experience)
+ experience = jax.tree.map(jnp.empty_like, experience)
 # Broadcast to [add_batch_size, max_length_time_axis]
- experience = jax.tree_map(
+ experience = jax.tree.map(
 lambda x: jnp.broadcast_to(
 x[None, None, ...], (add_batch_size, max_length_time_axis, *x.shape)
 ),
diff --git a/flashbax/buffers/trajectory_buffer_test.py b/flashbax/buffers/trajectory_buffer_test.py
index 3315fd0..908bb33 100644
--- a/flashbax/buffers/trajectory_buffer_test.py
+++ b/flashbax/buffers/trajectory_buffer_test.py
@@ -247,7 +247,7 @@ def test_sample_with_period(
 # Create a batch but specifically ensure that sequences in different add_batch rows
 # are distinct - this is simply for testing purposes in order to verify periodicity
- fake_batch_sequence = jax.tree_map(
+ fake_batch_sequence = jax.tree.map(
 lambda x: jnp.stack([x + i * (min_length + 10) for i in range(add_batch_size)]),
 fake_sequence,
 )
@@ -303,7 +303,7 @@ def test_trajectory_buffer_does_not_smoke(
 )
 # Initialise the buffer's state.
- fake_trajectory_per_device = jax.tree_map(
+ fake_trajectory_per_device = jax.tree.map(
 lambda x: jnp.stack([x + i for i in range(_DEVICE_COUNT_MOCK)]), fake_transition
 )
 state = jax.pmap(buffer.init)(fake_trajectory_per_device)
diff --git a/flashbax/buffers/trajectory_queue.py b/flashbax/buffers/trajectory_queue.py
index 35990f8..becf1c4 100644
--- a/flashbax/buffers/trajectory_queue.py
+++ b/flashbax/buffers/trajectory_queue.py
@@ -82,10 +82,10 @@ def init(
 been added yet.
 """
 # Set experience value to be empty.
- experience = jax.tree_map(jnp.empty_like, experience)
+ experience = jax.tree.map(jnp.empty_like, experience)
 # Broadcast to [add_batch_size, max_length_time_axis]
- experience = jax.tree_map(
+ experience = jax.tree.map(
 lambda x: jnp.broadcast_to(
 x[None, None, ...], (add_batch_size, max_length_time_axis, *x.shape)
 ),
diff --git a/flashbax/buffers/trajectory_queue_test.py b/flashbax/buffers/trajectory_queue_test.py
index 3eee4b2..3acd329 100644
--- a/flashbax/buffers/trajectory_queue_test.py
+++ b/flashbax/buffers/trajectory_queue_test.py
@@ -421,7 +421,7 @@ def test_trajectory_queue_does_not_smoke(
 )
 # Initialise the queue's state.
- fake_trajectory_per_device = jax.tree_map(
+ fake_trajectory_per_device = jax.tree.map(
 lambda x: jnp.stack([x + i for i in range(_DEVICE_COUNT_MOCK)]), fake_transition
 )
 state = jax.pmap(queue.init)(fake_trajectory_per_device)
diff --git a/flashbax/utils.py b/flashbax/utils.py
index 88ad680..bc34a15 100644
--- a/flashbax/utils.py
+++ b/flashbax/utils.py
@@ -59,12 +59,12 @@ def wrapper(*args, **kwargs):
 args = list(args)
 args[starting_arg_index:end_index] = [
- jax.tree_map(lambda x: jnp.expand_dims(x, axis=axis), a)
+ jax.tree.map(lambda x: jnp.expand_dims(x, axis=axis), a)
 for a in args[starting_arg_index:end_index]
 ]
 for k, v in kwargs.items():
 if kwargs_on_device_keys is None or k in kwargs_on_device_keys:
- kwargs[k] = jax.tree_map(lambda x: jnp.expand_dims(x, axis=1), v)
+ kwargs[k] = jax.tree.map(lambda x: jnp.expand_dims(x, axis=1), v)
 return func(*args, **kwargs)
 return wrapper
diff --git a/flashbax/utils_test.py b/flashbax/utils_test.py
index cb0693a..8bc5bc9 100644
--- a/flashbax/utils_test.py
+++ b/flashbax/utils_test.py
@@ -29,7 +29,7 @@ def fake_transition() -> chex.ArrayTree:
 def get_fake_batch(fake_transition: chex.ArrayTree, batch_size) -> chex.ArrayTree:
 """Create a tree with the same structure as `fake_transition` but with an extra batch axis. Each element across the batch dimension is different."""
- return jax.tree_map(
+ return jax.tree.map(
 lambda x: jnp.stack([x + i for i in range(batch_size)]), fake_transition
 )
diff --git a/flashbax/vault/vault.py b/flashbax/vault/vault.py
index c494ba4..290c84a 100644
--- a/flashbax/vault/vault.py
+++ b/flashbax/vault/vault.py
@@ -127,9 +127,11 @@ def __init__( # noqa: CCR001
 # Ensure provided metadata is json serialisable
 metadata_json_ready = jax.tree_util.tree_map(
- lambda obj: str(obj)
- if not isinstance(obj, (bool, str, int, float, type(None)))
- else obj,
+ lambda obj: (
+ str(obj)
+ if not isinstance(obj, (bool, str, int, float, type(None)))
+ else obj
+ ),
 metadata,
 )
@@ -137,11 +139,11 @@ def __init__( # noqa: CCR001
 # each leaf. We will use this structure to map over the data stores later.
 # (Note: we use `jax.eval_shape` to get shape and dtype of each leaf, without
 # unnecessarily serialising the buffer data itself)
- serialised_experience_structure_shape = jax.tree_map(
+ serialised_experience_structure_shape = jax.tree.map(
 lambda x: str(x.shape),
 serialize_tree(jax.eval_shape(lambda: experience_structure)),
 )
- serialised_experience_structure_dtype = jax.tree_map(
+ serialised_experience_structure_dtype = jax.tree.map(
 lambda x: x.dtype.name,
 serialize_tree(jax.eval_shape(lambda: experience_structure)),
 )
@@ -265,9 +267,11 @@ def _init_leaf(
 )
 leaf_dtype = dtype
 spec["metadata"] = {
- "compressor": COMPRESSION_DEFAULT
- if self._compression is None
- else self._compression
+ "compressor": (
+ COMPRESSION_DEFAULT
+ if self._compression is None
+ else self._compression
+ )
 }
 leaf_ds = ts.open(
diff --git a/flashbax/vault/vault_test.py b/flashbax/vault/vault_test.py
index 6a2bad7..f664917 100644
--- a/flashbax/vault/vault_test.py
+++ b/flashbax/vault/vault_test.py
@@ -190,7 +190,7 @@ def multiplier(x: Array, i: int):
 for i in range(0, max_length):
 buffer_state = buffer_add(
 buffer_state,
- jax.tree_map(partial(multiplier, i=i), fake_transition),
+ jax.tree.map(partial(multiplier, i=i), fake_transition),
 )
 v.write(buffer_state)
diff --git a/pyproject.toml b/pyproject.toml
index a2ef586..1f7f379 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,8 +35,8 @@ keywords=["reinforcement-learning", "python", "jax", "memory"]
 dependencies = [
 'flax>=0.6.11',
 'chex>=0.1.8',
- 'jax>=0.4.10',
- 'jaxlib>=0.4.10',
+ 'jax>=0.4.25',
+ 'jaxlib>=0.4.20',
 'numpy>=1.19.5',
 'typing_extensions>=4.6.0',
 'tensorstore>=0.1.51'