Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a pzflow inform estimate demo to the examples #155

Merged
merged 11 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .github/workflows/smoke-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@ jobs:
pip install .
pip install .[dev]
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: (Temporarily) roll back ceci to pre-v2
run: |
pip install ceci==1.17
- name: Run unit tests with pytest
run: |
python -m pytest tests
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/testing-and-coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11']
python-version: ['3.10', '3.11']

steps:
- uses: actions/checkout@v3
Expand Down
11 changes: 11 additions & 0 deletions examples/core_examples/pipe_example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,26 @@ stages:
module_name: rail.creation.engines.flowEngine
name: flow_engine_test
nprocess: 1
aliases:
output: output_flow_engine_test
- classname: LSSTErrorModel
module_name: rail.creation.degraders.lsst_error_model
name: lsst_error_model_test
nprocess: 1
aliases:
input: output_flow_engine_test
output: output_lsst_error_model_test
- classname: ColumnMapper
module_name: rail.tools.table_tools
name: col_remapper_test
nprocess: 1
aliases:
input: output_lsst_error_model_test
output: output_col_remapper_test
- classname: TableConverter
module_name: rail.tools.table_tools
name: table_conv_test
nprocess: 1
aliases:
input: output_col_remapper_test
output: output_table_conv_test
11 changes: 0 additions & 11 deletions examples/core_examples/pipe_example_config.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
col_remapper_test:
aliases:
input: output_lsst_error_model_test
output: output_col_remapper_test
chunk_size: 100000
columns:
mag_g_lsst_err: mag_err_g_lsst
Expand All @@ -17,8 +14,6 @@ col_remapper_test:
name: col_remapper_test
output_mode: default
flow_engine_test:
aliases:
output: output_flow_engine_test
config: null
model: ${FLOWDIR}/pretrained_flow.pkl
n_samples: 50
Expand All @@ -36,9 +31,6 @@ lsst_error_model_test:
y: 23.73
z: 24.16
airmass: 1.2
aliases:
input: output_flow_engine_test
output: output_lsst_error_model_test
bandNames:
g: mag_g_lsst
i: mag_i_lsst
Expand Down Expand Up @@ -96,9 +88,6 @@ lsst_error_model_test:
z: 0.69
tvis: 30.0
table_conv_test:
aliases:
input: output_col_remapper_test
output: output_table_conv_test
config: null
input: None
name: table_conv_test
Expand Down
179 changes: 179 additions & 0 deletions examples/estimation_examples/pzflow_demo.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "327d391f-58bc-4b6a-9bbe-3987b969c8f4",
"metadata": {},
"source": [
"PZFlow Informer and Estimator Demo\n",
"\n",
"Author: Tianqing Zhang\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "916a05ad",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import os\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"import rail\n",
"from rail.core.data import TableHandle\n",
"from rail.core.stage import RailStage\n",
"import qp\n",
"import tables_io\n",
"\n",
"from rail.estimation.algos.pzflow_nf import PZFlowInformer, PZFlowEstimator\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d8ef87d3",
"metadata": {},
"outputs": [],
"source": [
"DS = RailStage.data_store\n",
"DS.__class__.allow_overwrite = True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f79c3a7b",
"metadata": {},
"outputs": [],
"source": [
"from rail.utils.path_utils import find_rail_file\n",
"trainFile = find_rail_file('examples_data/testdata/test_dc2_training_9816.hdf5')\n",
"testFile = find_rail_file('examples_data/testdata/test_dc2_validation_9816.hdf5')\n",
"training_data = DS.read_file(\"training_data\", TableHandle, trainFile)\n",
"test_data = DS.read_file(\"test_data\", TableHandle, testFile)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "756d78a3",
"metadata": {},
"outputs": [],
"source": [
"pzflow_dict = dict(hdf5_groupname='photometry',output_mode = 'not_fiducial' )\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0857e6bb-18eb-4f89-bc4b-29bed1ffa122",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "1042a9f3",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# epoch = 200 gives a reasonable converged loss\n",
"pzflow_train = PZFlowInformer.make_stage(name='inform_pzflow',model='demo_pzflow.pkl',num_training_epochs = 30, **pzflow_dict)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c407f45b",
"metadata": {},
"outputs": [],
"source": [
"# training of the pzflow\n",
"pzflow_train.inform(training_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "156b6e3d",
"metadata": {},
"outputs": [],
"source": [
"pzflow_dict = dict(hdf5_groupname='photometry')\n",
"\n",
"pzflow_estimator = PZFlowEstimator.make_stage(name='estimate_pzflow',model='demo_pzflow.pkl',**pzflow_dict, chunk_size = 20000)"
sidneymau marked this conversation as resolved.
Show resolved Hide resolved
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "00911d60",
"metadata": {},
"outputs": [],
"source": [
"# estimate using the test data\n",
"estimate_results = pzflow_estimator.estimate(test_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4cbdece3",
"metadata": {},
"outputs": [],
"source": [
"mode = estimate_results.read(force=True).ancil['zmode']\n",
"truth = np.array(test_data.data['photometry']['redshift'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ba076bab-c5ab-4292-8de9-415e7b30af5c",
"metadata": {},
"outputs": [],
"source": [
"# visualize the prediction. \n",
"plt.figure(figsize = (8,8))\n",
"plt.scatter(truth, mode, s = 0.5)\n",
"plt.xlabel('True Redshift')\n",
"plt.ylabel('Mode of Estimated Redshift')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ed5bf266-3b5c-4d9b-8428-77a2833cafef",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading