Skip to content

Commit

Permalink
Inference bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
georgematheos authored Sep 13, 2024
1 parent b68cab6 commit d6faa72
Show file tree
Hide file tree
Showing 4 changed files with 470 additions and 21 deletions.
47 changes: 33 additions & 14 deletions notebooks/bayes3d_paper/tester.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 49/49 [00:03<00:00, 13.47it/s]\n",
"100%|██████████| 49/49 [00:03<00:00, 13.41it/s]\n",
"/home/georgematheos/b3d/.pixi/envs/gpu/lib/python3.12/site-packages/torch/utils/cpp_extension.py:1967: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. \n",
"If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].\n",
" warnings.warn(\n"
Expand Down Expand Up @@ -440,7 +440,7 @@
"metadata": {},
"outputs": [],
"source": [
"b3d.rr_init(\"inference_given_gtpose_4\")"
"b3d.rr_init(\"inference_given_gtpose_5\")"
]
},
{
Expand All @@ -454,23 +454,23 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"inference_hyperparams = i.InferenceHyperparams(\n",
" n_poses=1500,\n",
" do_stochastic_color_proposals=False,\n",
" do_stochastic_color_proposals=True,\n",
" pose_proposal_std=0.04,\n",
" pose_proposal_conc=1000.,\n",
" prev_color_proposal_laplace_scale=.04,\n",
" obs_color_proposal_laplace_scale=.01,\n",
" obs_color_proposal_laplace_scale=.02,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 35,
"metadata": {},
"outputs": [
{
Expand All @@ -484,9 +484,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/home/georgematheos/b3d/src/b3d/modeling_utils.py:86: UserWarning: RenormalizedLaplace sampling is currently not implemented perfectly.\n",
" warnings.warn(\n",
"100%|██████████| 30/30 [00:32<00:00, 1.09s/it]\n"
"100%|██████████| 30/30 [00:09<00:00, 3.02it/s]\n"
]
}
],
Expand Down Expand Up @@ -539,11 +537,11 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"b3d.rr_init(\"real_inference2\")"
"b3d.rr_init(\"real_inference_3\")"
]
},
{
Expand All @@ -562,7 +560,11 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 20/20 [01:51<00:00, 5.57s/it]\n"
"/home/georgematheos/b3d/src/b3d/modeling_utils.py:86: UserWarning: RenormalizedLaplace sampling is currently not implemented perfectly.\n",
" warnings.warn(\n",
"/home/georgematheos/b3d/src/b3d/modeling_utils.py:86: UserWarning: RenormalizedLaplace sampling is currently not implemented perfectly.\n",
" warnings.warn(\n",
"100%|██████████| 20/20 [02:35<00:00, 7.75s/it]\n"
]
}
],
Expand Down Expand Up @@ -593,7 +595,7 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 29,
"metadata": {},
"outputs": [
{
Expand All @@ -607,12 +609,29 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 29/29 [02:40<00:00, 5.55s/it]\n"
" 14%|█▍ | 4/29 [00:40<04:16, 10.25s/it]\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[29], line 6\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m T \u001b[38;5;129;01min\u001b[39;00m tqdm(\u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m20\u001b[39m, \u001b[38;5;28mlen\u001b[39m(all_data))):\n\u001b[1;32m 5\u001b[0m key \u001b[38;5;241m=\u001b[39m b3d\u001b[38;5;241m.\u001b[39msplit_key(key)\n\u001b[0;32m----> 6\u001b[0m trace \u001b[38;5;241m=\u001b[39m \u001b[43mi\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minference_step_c2f\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 7\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 8\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# number of sequential iterations of the parallel pose proposal to consider\u001b[39;49;00m\n\u001b[1;32m 9\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m5000\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# number of poses to propose in parallel\u001b[39;49;00m\n\u001b[1;32m 10\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# So the total number of poses considered at each step of C2F is 5000 * 1\u001b[39;49;00m\n\u001b[1;32m 11\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mall_data\u001b[49m\u001b[43m[\u001b[49m\u001b[43mT\u001b[49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mrgbd\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 12\u001b[0m \u001b[43m \u001b[49m\u001b[43mprev_color_proposal_laplace_scale\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minference_hyperparams\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprev_color_proposal_laplace_scale\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 13\u001b[0m \u001b[43m \u001b[49m\u001b[43mobs_color_proposal_laplace_scale\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minference_hyperparams\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mobs_color_proposal_laplace_scale\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 14\u001b[0m \u001b[43m \u001b[49m\u001b[43mdo_stochastic_color_proposals\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\n\u001b[1;32m 15\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 16\u001b[0m b3d\u001b[38;5;241m.\u001b[39mchisight\u001b[38;5;241m.\u001b[39mgen3d\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mviz_trace(\n\u001b[1;32m 17\u001b[0m trace,\n\u001b[1;32m 18\u001b[0m T,\n\u001b[1;32m 19\u001b[0m ground_truth_vertices\u001b[38;5;241m=\u001b[39mmeshes[OBJECT_INDEX]\u001b[38;5;241m.\u001b[39mvertices,\n\u001b[1;32m 20\u001b[0m ground_truth_pose\u001b[38;5;241m=\u001b[39mall_data[T][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcamera_pose\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39minv() \u001b[38;5;241m@\u001b[39m all_data[T][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mobject_poses\u001b[39m\u001b[38;5;124m\"\u001b[39m][OBJECT_INDEX]\n\u001b[1;32m 21\u001b[0m )\n",
"File \u001b[0;32m~/b3d/src/b3d/chisight/gen3d/inference.py:100\u001b[0m, in \u001b[0;36minference_step_c2f\u001b[0;34m(key, n_seq, n_poses_per_sequential_step, old_trace, observed_rgbd, *args, **kwargs)\u001b[0m\n\u001b[1;32m 98\u001b[0m k1, k2 \u001b[38;5;241m=\u001b[39m split(key)\n\u001b[1;32m 99\u001b[0m trace \u001b[38;5;241m=\u001b[39m advance_time(k1, old_trace, observed_rgbd)\n\u001b[0;32m--> 100\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minfer_latents_c2f\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 101\u001b[0m \u001b[43m \u001b[49m\u001b[43mk2\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_seq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_poses_per_sequential_step\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m 102\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/b3d/src/b3d/chisight/gen3d/inference.py:121\u001b[0m, in \u001b[0;36minfer_latents_c2f\u001b[0;34m(key, n_seq, n_poses_per_sequential_step, trace, pose_proposal_std_conc_seq, **inference_hyperparam_kwargs)\u001b[0m\n\u001b[1;32m 114\u001b[0m inference_hyperparams \u001b[38;5;241m=\u001b[39m InferenceHyperparams(\n\u001b[1;32m 115\u001b[0m n_poses\u001b[38;5;241m=\u001b[39mn_poses_per_sequential_step,\n\u001b[1;32m 116\u001b[0m pose_proposal_std\u001b[38;5;241m=\u001b[39mstd,\n\u001b[1;32m 117\u001b[0m pose_proposal_conc\u001b[38;5;241m=\u001b[39mconc,\n\u001b[1;32m 118\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39minference_hyperparam_kwargs,\n\u001b[1;32m 119\u001b[0m )\n\u001b[1;32m 120\u001b[0m key, _ \u001b[38;5;241m=\u001b[39m split(key)\n\u001b[0;32m--> 121\u001b[0m trace, _ \u001b[38;5;241m=\u001b[39m \u001b[43minfer_latents_using_sequential_proposals\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 122\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_seq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minference_hyperparams\u001b[49m\n\u001b[1;32m 123\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 125\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m trace\n",
"File \u001b[0;32m~/b3d/src/b3d/chisight/gen3d/inference.py:153\u001b[0m, in \u001b[0;36minfer_latents_using_sequential_proposals\u001b[0;34m(key, n_seq, trace, inference_hyperparams)\u001b[0m\n\u001b[1;32m 151\u001b[0m k1, k2 \u001b[38;5;241m=\u001b[39m split(key)\n\u001b[1;32m 152\u001b[0m ks \u001b[38;5;241m=\u001b[39m split(k1, n_seq)\n\u001b[0;32m--> 153\u001b[0m weights \u001b[38;5;241m=\u001b[39m [\u001b[43mget_weight\u001b[49m\u001b[43m(\u001b[49m\u001b[43mk\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m ks]\n\u001b[1;32m 155\u001b[0m normalized_logps \u001b[38;5;241m=\u001b[39m jax\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mlog_softmax(jnp\u001b[38;5;241m.\u001b[39marray(weights))\n\u001b[1;32m 156\u001b[0m chosen_idx \u001b[38;5;241m=\u001b[39m jax\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39mcategorical(k2, normalized_logps)\n",
"File \u001b[0;32m~/b3d/src/b3d/chisight/gen3d/inference.py:149\u001b[0m, in \u001b[0;36minfer_latents_using_sequential_proposals.<locals>.get_weight\u001b[0;34m(key)\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_weight\u001b[39m(key):\n\u001b[0;32m--> 149\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minfer_latents\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mshared_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mget_trace\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mget_metadata\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"## Finish the run\n",
"key = jax.random.PRNGKey(1234)\n",
"trace = trace_20\n",
"for T in tqdm(range(20, len(all_data))):\n",
" key = b3d.split_key(key)\n",
" trace = i.inference_step_c2f(\n",
Expand Down
432 changes: 432 additions & 0 deletions notebooks/bayes3d_paper/tester2.ipynb

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions src/b3d/chisight/gen3d/image_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,11 @@ def logpdf(
)
# Points that don't hit the camera plane should not contribute to the score.
scores = jnp.where(is_unexplained(observed_rgbd_per_point), 0.0, scores)
score_for_pixels_with_points = scores.sum()

# TODO: add scoring for pixels that are not explained by the latent points
# TODO: add scores for pixels that don't get a point

return scores.sum()
return score_for_pixels_with_points

def get_rgbd_vertex_kernel(self) -> PixelRGBDDistribution:
# Note: The distributions were originally defined for per-pixel computation,
Expand Down
7 changes: 2 additions & 5 deletions src/b3d/chisight/gen3d/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@
propose_other_latents_given_pose,
propose_pose,
)
from b3d.chisight.gen3d.model import (
get_hypers,
get_prev_state,
)
from b3d.chisight.gen3d.model import get_hypers, get_new_state


@Pytree.dataclass
Expand Down Expand Up @@ -60,7 +57,7 @@ def advance_time(key, trace, observed_rgbd):
U.g(
(
Diff.no_change(get_hypers(trace)),
Diff.unknown_change(get_prev_state(trace)),
Diff.unknown_change(get_new_state(trace)),

This comment has been minimized.

Copy link
@horizon-blue

horizon-blue Sep 13, 2024

Contributor

image

),
C.kw(rgbd=observed_rgbd),
),
Expand Down

0 comments on commit d6faa72

Please sign in to comment.