Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
ztq1996 committed Jul 30, 2024
1 parent b6953c8 commit fa50cdb
Showing 1 changed file with 184 additions and 0 deletions.
184 changes: 184 additions & 0 deletions examples/estimation_examples/pzflow_demo.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "327d391f-58bc-4b6a-9bbe-3987b969c8f4",
"metadata": {},
"source": [
"PZFlow Informer and Estimator Demo\n",
"\n",
"Author: Tianqing Zhang\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "916a05ad",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import os\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"import rail\n",
"from rail.core.data import TableHandle\n",
"from rail.core.stage import RailStage\n",
"import qp\n",
"import tables_io\n",
"\n",
"from rail.estimation.algos.pzflow_nf_default import PZFlowInformer, PZFlowEstimator\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d8ef87d3",
"metadata": {},
"outputs": [],
"source": [
"DS = RailStage.data_store\n",
"DS.__class__.allow_overwrite = True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f79c3a7b",
"metadata": {},
"outputs": [],
"source": [
"from rail.utils.path_utils import find_rail_file\n",
"trainFile = find_rail_file('examples_data/testdata/test_dc2_training_9816.hdf5')\n",
"testFile = find_rail_file('examples_data/testdata/test_dc2_validation_9816.hdf5')\n",
"training_data = DS.read_file(\"training_data\", TableHandle, trainFile)\n",
"test_data = DS.read_file(\"test_data\", TableHandle, testFile)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "756d78a3",
"metadata": {},
"outputs": [],
"source": [
"pzflow_dict = dict(hdf5_groupname='photometry')\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1042a9f3",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# epoch = 200 gives a reasonable converged loss\n",
"pzflow_train = PZFlowInformer.make_stage(name='inform_pzflow',model='demo_pzflow.pkl',num_training_epochs = 200, **pzflow_dict)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c4335359",
"metadata": {},
"outputs": [],
"source": [
"# delete after the ceci 2.0 pr merged\n",
"pzflow_train._aliases = dict()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c407f45b",
"metadata": {},
"outputs": [],
"source": [
"# training of the pzflow\n",
"pzflow_train.inform(training_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "156b6e3d",
"metadata": {},
"outputs": [],
"source": [
"pzflow_dict = dict(hdf5_groupname='photometry')\n",
"\n",
"pzflow_estimator = PZFlowEstimator.make_stage(name='estimate_pzflow',model='demo_pzflow.pkl',**pzflow_dict, chunk_size = 20000)\n",
"\n",
"pzflow_estimator._aliases = dict()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "00911d60",
"metadata": {},
"outputs": [],
"source": [
"# estimate using the test data\n",
"estimate_results = pzflow_estimator.estimate(test_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4cbdece3",
"metadata": {},
"outputs": [],
"source": [
"mode = estimate_results.data.ancil['zmode']\n",
"truth = np.array(training_data.data['photometry']['redshift'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ba076bab-c5ab-4292-8de9-415e7b30af5c",
"metadata": {},
"outputs": [],
"source": [
"# visualize the prediction. \n",
"plt.figure(figsize = (8,8))\n",
"plt.scatter(truth, mode, s = 0.5)\n",
"plt.xlabel('True Redshift')\n",
"plt.ylabel('Mode of Estimated Redshift')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ed5bf266-3b5c-4d9b-8428-77a2833cafef",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit fa50cdb

Please sign in to comment.