Skip to content

Commit

Permalink
feat: add some comments to demo notebook.
Browse files Browse the repository at this point in the history
  • Loading branch information
callumtilbury committed Jul 17, 2024
1 parent fc42692 commit d34b816
Showing 1 changed file with 22 additions and 20 deletions.
42 changes: 22 additions & 20 deletions examples/mixer_demonstration.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,14 @@
"import flashbax as fbx\n",
"import jax.numpy as jnp\n",
"from jax.tree_util import tree_map\n",
"import jax"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"\n",
"key = jax.random.PRNGKey(0)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 2,
"metadata": {},
"outputs": [
{
Expand All @@ -32,12 +25,13 @@
"TrajectoryBufferSample(experience={'acts': (4, 5, 3), 'obs': (4, 5, 2)})"
]
},
"execution_count": 7,
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Create our first buffer, with a sample batch size of 4\n",
"buffer_a = fbx.make_trajectory_buffer(\n",
" add_batch_size=1,\n",
" max_length_time_axis=1000,\n",
Expand All @@ -56,6 +50,7 @@
" timestep,\n",
")\n",
"for i in range(100):\n",
" # Fill with POSITIVE values\n",
" state_a = jax.jit(buffer_a.add, donate_argnums=0)(\n",
" state_a,\n",
" tree_map(lambda x, _i=i: (x * _i)[None, None, ...], timestep),\n",
Expand All @@ -67,7 +62,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 3,
"metadata": {},
"outputs": [
{
Expand All @@ -76,12 +71,13 @@
"TrajectoryBufferSample(experience={'acts': (16, 5, 3), 'obs': (16, 5, 2)})"
]
},
"execution_count": 8,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Create our second buffer, with a sample batch size of 16\n",
"buffer_b = fbx.make_trajectory_buffer(\n",
" add_batch_size=1,\n",
" max_length_time_axis=1000,\n",
Expand All @@ -100,6 +96,7 @@
" timestep,\n",
")\n",
"for i in range(100):\n",
" # Fill with NEGATIVE values\n",
" state_b = jax.jit(buffer_b.add, donate_argnums=0)(\n",
" state_b,\n",
" tree_map(lambda x, _i=i: (- x * _i)[None, None, ...], timestep),\n",
Expand All @@ -111,21 +108,24 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Make the mixer, with a ratio of 1:3 from buffer_a:buffer_b\n",
"mixer = fbx.make_mixer(\n",
" buffers=[buffer_a, buffer_b],\n",
" sample_batch_size=8,\n",
" proportions=[1,3],\n",
")\n",
"\n",
"# jittable sampling!\n",
"mixer_sample = jax.jit(mixer.sample)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 5,
"metadata": {},
"outputs": [
{
Expand All @@ -134,23 +134,25 @@
"TrajectoryBufferSample(experience={'acts': (8, 5, 3), 'obs': (8, 5, 2)})"
]
},
"execution_count": 13,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sample from the mixer, using the usual flashbax API\n",
"joint_sample = mixer_sample(\n",
" [state_a, state_b],\n",
" key,\n",
")\n",
"\n",
"# Notice the resulting shape\n",
"tree_map(lambda x: x.shape, joint_sample)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 6,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -179,7 +181,7 @@
" [60., 60.]]], dtype=float32)})"
]
},
"execution_count": 11,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -191,7 +193,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 7,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -268,7 +270,7 @@
" [-19., -19.]]], dtype=float32)})"
]
},
"execution_count": 12,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
Expand Down

0 comments on commit d34b816

Please sign in to comment.