Skip to content

Commit

Permalink
Iterators working for DistToPoint and DistToDist. Not working for PointToPoint. Need to address cumulative metrics for DistToPoint (i.e. metrics that require _all_ the data to be evaluated at once).
Browse files Browse the repository at this point in the history
  • Loading branch information
drewoldag committed Nov 30, 2023
1 parent 7f196ea commit 61480e3
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 35 deletions.
15 changes: 9 additions & 6 deletions src/rail/evaluation/dist_to_dist_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

from ceci.config import StageParameter as Param
from qp.metrics.base_metric_classes import DistToDistMetric
from qp.metrics.concrete_metric_classes import DistToDistMetric

from rail.core.data import Hdf5Handle, QPHandle
from rail.core.stage import RailStage
Expand Down Expand Up @@ -43,13 +43,16 @@ def __init__(self, args, comm=None):
def run(self):
print(f"Requested metrics: {self.config.metrics}")

estimate_iterator = self.get_handle('input').iterator()
reference_iterator = self.get_handle('truth').iterator()
estimate_iterator = self.input_iterator('input')
reference_iterator = self.input_iterator('truth')

first = True
for s, e, estimate_data, _, _, reference_data in zip(estimate_iterator, reference_iterator):
print(f"Processing {self.rank} running evaluator on chunk {s} - {e}.")
self._process_chunk(s, e, estimate_data, reference_data, first)
for estimate_data_chunk, reference_data_chunk in zip(estimate_iterator, reference_iterator):
chunk_start, chunk_end, estimate_data = estimate_data_chunk
_, _, reference_data = reference_data_chunk

print(f"Processing {self.rank} running evaluator on chunk {chunk_start} - {chunk_end}.")
self._process_chunk(chunk_start, chunk_end, estimate_data, reference_data, first)
first = False

self._output_handle.finalize_write()
Expand Down
15 changes: 10 additions & 5 deletions src/rail/evaluation/dist_to_point_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

from ceci.config import StageParameter as Param
from qp.metrics.base_metric_classes import DistToPointMetric
from qp.metrics.concrete_metric_classes import DistToPointMetric

from rail.core.data import Hdf5Handle, QPHandle, TableHandle
from rail.core.stage import RailStage
Expand Down Expand Up @@ -34,6 +34,8 @@ class DistToPointEvaluator(Evaluator):
msg="Random seed value to use for reproducible results."),
hdf5_groupname=Param(str, "photometry", required=False,
msg="HDF5 Groupname for truth table."),
reference_dictionary_key=Param(str, "redshift", required=False,
msg="The key in the `truth` dictionary where the redshift data is stored."),
)
inputs = [('input', QPHandle),
('truth', TableHandle)]
Expand All @@ -51,9 +53,12 @@ def run(self):
reference_iterator = self.input_iterator('truth')

first = True
for s, e, estimate_data, _, _, reference_data in zip(estimate_iterator, reference_iterator):
print(f"Processing {self.rank} running evaluator on chunk {s} - {e}.")
self._process_chunk(s, e, estimate_data, reference_data, first)
for estimate_data_chunk, reference_data_chunk in zip(estimate_iterator, reference_iterator):
chunk_start, chunk_end, estimate_data = estimate_data_chunk
_, _, reference_data = reference_data_chunk

print(f"Processing {self.rank} running evaluator on chunk {chunk_start} - {chunk_end}.")
self._process_chunk(chunk_start, chunk_end, estimate_data, reference_data, first)
first = False

self._output_handle.finalize_write()
Expand All @@ -69,7 +74,7 @@ def _process_chunk(self, start, end, estimate_data, reference_data, first):
continue

this_metric = self._metric_dict[metric](**self.config.to_dict())
out_table[metric] = this_metric.evaluate(estimate_data, reference_data)
out_table[metric] = this_metric.evaluate(estimate_data, reference_data[self.config.reference_dictionary_key])

out_table_to_write = {key: np.array(val).astype(float) for key, val in out_table.items()}

Expand Down
13 changes: 8 additions & 5 deletions src/rail/evaluation/point_to_point_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,16 @@ def __init__(self, args, comm=None):
def run(self):
print(f"Requested metrics: {self.config.metrics}")

estimate_iterator = self.get_handle('input').iterator()
reference_iterator = self.get_handle('truth').iterator()
estimate_iterator = self.input_iterator('input')
reference_iterator = self.input_iterator('truth')

first = True
for s, e, estimate_data, _, _, reference_data in zip(estimate_iterator, reference_iterator):
print(f"Processing {self.rank} running evaluator on chunk {s} - {e}.")
self._process_chunk(s, e, estimate_data, reference_data, first)
for estimate_data_chunk, reference_data_chunk in zip(estimate_iterator, reference_iterator):
chunk_start, chunk_end, estimate_data = estimate_data_chunk
_, _, reference_data = reference_data_chunk

print(f"Processing {self.rank} running evaluator on chunk {chunk_start} - {chunk_end}.")
self._process_chunk(chunk_start, chunk_end, estimate_data, reference_data, first)
first = False

self._output_handle.finalize_write()
Expand Down
116 changes: 97 additions & 19 deletions src/rail/evaluation/testing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
"metadata": {},
"outputs": [],
"source": [
"import tables_io\n",
"\n",
"from rail.evaluation.dist_to_dist_evaluator import DistToDistEvaluator\n",
"from rail.evaluation.dist_to_point_evaluator import DistToPointEvaluator\n",
"from rail.evaluation.point_to_point_evaluator import PointToPointEvaluator\n",
"from rail.core.stage import RailStage\n",
"from rail.core.data import QPHandle, TableHandle\n",
"\n",
Expand All @@ -15,20 +19,10 @@
]
},
{
"cell_type": "code",
"execution_count": null,
"cell_type": "markdown",
"metadata": {},
"outputs": [],
"source": [
"# 'cvm' takes about 3.5 minutes to run\n",
"# 'ad' takes about ~4 minutes to run\n",
"# 'ks' takes about 2.75 minutes to run\n",
"# 'kld' takes about X minutes to run\n",
"stage_dict = dict(\n",
" metrics=['cvm', 'ks', 'omega', 'kld'],\n",
" _random_state=None,\n",
")\n",
"squish_fish = DistToPointEvaluator.make_stage(name='SillyPoopfish', **stage_dict)\n"
"# Load example Data"
]
},
{
Expand Down Expand Up @@ -61,7 +55,15 @@
"outputs": [],
"source": [
"ensemble = DS.read_file(key='pdfs_data', handle_class=QPHandle, path=pdfs_file)\n",
"ztrue_data = DS.read_file('ztrue_data', TableHandle, ztrue_file)"
"ztrue_data = DS.read_file('ztrue_data', TableHandle, ztrue_file)\n",
"truth = DS.add_data('truth', ztrue_data()['photometry'], TableHandle, path=ztrue_file)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Dist to Dist Evaluation"
]
},
{
Expand All @@ -70,8 +72,17 @@
"metadata": {},
"outputs": [],
"source": [
"truth = DS.add_data('truth', ztrue_data()['photometry'], TableHandle, path=ztrue_file)\n",
"# ensemble = DS.add_data('ensemble', fzdata(), QPHandle, path=pdfs_file)"
"# 'cvm' takes about 3.5 minutes to run\n",
"# 'ad' takes about ~4 minutes to run\n",
"# 'ks' takes about 2.75 minutes to run\n",
"# 'kld' takes about X minutes to run\n",
"\n",
"stage_dict = dict(\n",
" metrics=['cvm', 'ks', 'omega', 'kld'],\n",
" _random_state=None,\n",
")\n",
"\n",
"dtd_stage = DistToDistEvaluator.make_stage(name='SillyPoopfish', **stage_dict)"
]
},
{
Expand All @@ -80,7 +91,7 @@
"metadata": {},
"outputs": [],
"source": [
"squish_results = squish_fish.evaluate(ensemble, truth)"
"dtd_results = dtd_stage.evaluate(ensemble, ensemble)"
]
},
{
Expand All @@ -89,10 +100,77 @@
"metadata": {},
"outputs": [],
"source": [
"import tables_io\n",
"results_df= tables_io.convertObj(squish_results(), tables_io.types.PD_DATAFRAME)\n",
"results_df = tables_io.convertObj(dtd_results(), tables_io.types.PD_DATAFRAME)\n",
"results_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Dist to Point Evaluation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"stage_dict = dict(\n",
" metrics=['cdeloss'],\n",
" _random_state=None,\n",
")\n",
"dtp_stage = DistToPointEvaluator.make_stage(name='SillyPoopfish', **stage_dict)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dtp_results = dtp_stage.evaluate(ensemble, truth)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"results_df = tables_io.convertObj(dtp_results(), tables_io.types.PD_DATAFRAME)\n",
"results_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Point to Point Evaluation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"stage_dict = dict(\n",
" metrics=['point_stats_ez'],\n",
" _random_state=None,\n",
")\n",
"ptp_stage = PointToPointEvaluator.make_stage(name='SillyPoopfish', **stage_dict)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ptp_results = ptp_stage.evaluate(truth, truth)"
]
}
],
"metadata": {
Expand All @@ -111,7 +189,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.12"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 61480e3

Please sign in to comment.