diff --git a/examples/mixer_demonstration.ipynb b/examples/mixer_demonstration.ipynb index 15fe47b..7b07a40 100644 --- a/examples/mixer_demonstration.ipynb +++ b/examples/mixer_demonstration.ipynb @@ -9,21 +9,14 @@ "import flashbax as fbx\n", "import jax.numpy as jnp\n", "from jax.tree_util import tree_map\n", - "import jax" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ + "import jax\n", + "\n", "key = jax.random.PRNGKey(0)" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -32,12 +25,13 @@ "TrajectoryBufferSample(experience={'acts': (4, 5, 3), 'obs': (4, 5, 2)})" ] }, - "execution_count": 7, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "# Create our first buffer, with a sample batch size of 4\n", "buffer_a = fbx.make_trajectory_buffer(\n", " add_batch_size=1,\n", " max_length_time_axis=1000,\n", @@ -56,6 +50,7 @@ " timestep,\n", ")\n", "for i in range(100):\n", + " # Fill with POSITIVE values\n", " state_a = jax.jit(buffer_a.add, donate_argnums=0)(\n", " state_a,\n", " tree_map(lambda x, _i=i: (x * _i)[None, None, ...], timestep),\n", @@ -67,7 +62,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -76,12 +71,13 @@ "TrajectoryBufferSample(experience={'acts': (16, 5, 3), 'obs': (16, 5, 2)})" ] }, - "execution_count": 8, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "# Create our second buffer, with a sample batch size of 16\n", "buffer_b = fbx.make_trajectory_buffer(\n", " add_batch_size=1,\n", " max_length_time_axis=1000,\n", @@ -100,6 +96,7 @@ " timestep,\n", ")\n", "for i in range(100):\n", + " # Fill with NEGATIVE values\n", " state_b = jax.jit(buffer_b.add, donate_argnums=0)(\n", " state_b,\n", " tree_map(lambda x, _i=i: (- x * _i)[None, None, ...], timestep),\n", @@ -111,21 +108,24 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ + "# Make the mixer, with a ratio of 1:3 from buffer_a:buffer_b\n", "mixer = fbx.make_mixer(\n", " buffers=[buffer_a, buffer_b],\n", " sample_batch_size=8,\n", " proportions=[1,3],\n", ")\n", + "\n", + "# jittable sampling!\n", "mixer_sample = jax.jit(mixer.sample)" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -134,23 +134,25 @@ "TrajectoryBufferSample(experience={'acts': (8, 5, 3), 'obs': (8, 5, 2)})" ] }, - "execution_count": 13, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "# Sample from the mixer, using the usual flashbax API\n", "joint_sample = mixer_sample(\n", " [state_a, state_b],\n", " key,\n", ")\n", "\n", + "# Notice the resulting shape\n", "tree_map(lambda x: x.shape, joint_sample)" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -179,7 +181,7 @@ " [60., 60.]]], dtype=float32)})" ] }, - "execution_count": 11, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -191,7 +193,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -268,7 +270,7 @@ " [-19., -19.]]], dtype=float32)})" ] }, - "execution_count": 12, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" }