Skip to content

Commit

Permalink
[doc] Update documentation for 3.0 (WIP)
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Feb 28, 2024
1 parent 989ec0c commit b6d125f
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 9 deletions.
6 changes: 3 additions & 3 deletions docs/Billiards.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,12 @@
" dist = v_next.points - rename_dims(v_next.points, 'balls', 'others')\n",
" dist_norm = math.vec_length(dist, eps=1e-4) # eps to avoid NaN during backprop of sqrt\n",
" rel_v = v.values - rename_dims(v.values, 'balls', 'others')\n",
" dist_dir = math.divide_no_nan(dist, dist_norm)\n",
" dist_dir = math.safe_div(dist, dist_norm)\n",
" projected_v = dist_dir.vector * rel_v.vector\n",
" has_impact = (projected_v < 0) & (dist_norm < 2 * v.elements.radius)\n",
" impulse = -(1 + elasticity) * .5 * projected_v * dist_dir\n",
" radius_sum = v.elements.radius + rename_dims(v.elements.radius, 'balls', 'others')\n",
" impact_time = math.divide_no_nan(dist_norm - radius_sum, projected_v)\n",
" impact_time = math.safe_div(dist_norm - radius_sum, projected_v)\n",
" x_inc_contrib = math.sum(math.where(has_impact, math.minimum(impact_time - dt, 0) * impulse, 0), 'others')\n",
" v = v.with_elements(v.elements.shifted(x_inc_contrib))\n",
" v += math.sum(math.where(has_impact, impulse, 0), 'others')\n",
Expand Down Expand Up @@ -1237,7 +1237,7 @@
{
"cell_type": "code",
"source": [
"loss_grad = math.functional_gradient(loss_function, 'x0,v0')\n",
"loss_grad = math.gradient(loss_function, 'x0,v0')\n",
"x0 = vec(x=.1, y=.5)\n",
"v0 = vec(x=.3, y=0)\n",
"learning_rate = .01\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/Fluids_Tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1805,7 +1805,7 @@
},
"outputs": [],
"source": [
"sim_grad = field.functional_gradient(simulate, wrt='velocity', get_output=False)"
"sim_grad = field.gradient(simulate, wrt='velocity', get_output=False)"
]
},
{
Expand Down Expand Up @@ -1917,7 +1917,7 @@
}
],
"source": [
"sim_grad = field.functional_gradient(simulate, wrt='velocity', get_output=True)\n",
"sim_grad = field.gradient(simulate, wrt='velocity', get_output=True)\n",
"\n",
"for opt_step in range(4):\n",
" (loss, final_smoke, _v), velocity_grad = sim_grad(initial_smoke, initial_velocity)\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/Known_Issues.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ This is because Jax requires all parameters including network weights to be decl


### Do no compute gradients (PyTorch)
Do not call `math.functional_gradient` within a jit-compiled function.
Do not call `math.gradient` within a jit-compiled function.
PyTorch cannot trace backward passes.


Expand Down
2 changes: 1 addition & 1 deletion docs/Learn_to_Throw_Tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@
}
],
"source": [
"gradient = math.functional_gradient(loss_function, get_output=True)\n",
"gradient = math.gradient(loss_function, get_output=True)\n",
"gradient(vel)"
],
"metadata": {
Expand Down
4 changes: 2 additions & 2 deletions docs/Planets_Tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@
],
"source": [
"v = math.rotate_vector(x, PI/2)\n",
"v = math.divide_no_nan(v, math.vec_length(v))\n",
"v = math.safe_div(v, math.vec_length(v))\n",
"plot(PointCloud(x, values=v, bounds=Box(x=(-2, 12), y=(-1, 13))), color=COLOR)"
]
},
Expand Down Expand Up @@ -269,7 +269,7 @@
"source": [
"def simulate(x, v, dt=.5):\n",
" dx = math.pairwise_distances(x)\n",
" a = .01 * math.sum(math.divide_no_nan(masses.planets.as_dual() * dx, math.vec_squared(dx) ** 1.5), '~planets')\n",
" a = .01 * math.sum(math.safe_div(masses.planets.as_dual() * dx, math.vec_squared(dx) ** 1.5), '~planets')\n",
" return x + v * dt, v + a * dt\n",
"\n",
"xs, vs = iterate(simulate, batch(time=100), x, v)\n",
Expand Down

0 comments on commit b6d125f

Please sign in to comment.