Skip to content

Commit

Permalink
Merge pull request #31 from LSSTDESC/user/aimalz/renaming
Browse files Browse the repository at this point in the history
naming consistency/clarity within src/rail/estimation
  • Loading branch information
aimalz authored Jul 14, 2023
2 parents 33ed647 + eb6e832 commit eca1dcb
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 22 deletions.
8 changes: 4 additions & 4 deletions docs/notebooks/fzboost_pdf_representation_comparison.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@
"metadata": {},
"outputs": [],
"source": [
"from rail.estimation.algos.flexzboost import Inform_FZBoost, FZBoost\n",
"inform_pzflex = Inform_FZBoost.make_stage(name='inform_fzboost', model=fz_modelfile, **fz_dict)"
"from rail.estimation.algos.flexzboost import FlexZBoostInformer, FlexZBoostEstimator\n",
"inform_pzflex = FlexZBoostInformer.make_stage(name='inform_fzboost', model=fz_modelfile, **fz_dict)"
]
},
{
Expand Down Expand Up @@ -132,7 +132,7 @@
"metadata": {},
"outputs": [],
"source": [
"pzflex_qp_flexzboost = FZBoost.make_stage(name='fzboost_flexzboost', hdf5_groupname='photometry',\n",
"pzflex_qp_flexzboost = FlexZBoostEstimator.make_stage(name='fzboost_flexzboost', hdf5_groupname='photometry',\n",
" model=inform_pzflex.get_handle('model'),\n",
" output='flexzboost.hdf5',\n",
" qp_representation='flexzboost')"
Expand Down Expand Up @@ -258,7 +258,7 @@
"metadata": {},
"outputs": [],
"source": [
"pzflex_qp_interp = FZBoost.make_stage(name='fzboost_interp', hdf5_groupname='photometry',\n",
"pzflex_qp_interp = FlexZBoostEstimator.make_stage(name='fzboost_interp', hdf5_groupname='photometry',\n",
" model=inform_pzflex.get_handle('model'),\n",
" output='interp.hdf5',\n",
" qp_representation='interp')"
Expand Down
12 changes: 6 additions & 6 deletions src/rail/estimation/algos/flexzboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ def make_color_data(data_dict, bands, err_bands, ref_band):
return input_data.T


class Inform_FZBoost(CatInformer):
""" Train a FZBoost CatEstimator
class FlexZBoostInformer(CatInformer):
""" Train a FlexZBoost CatInformer
"""
name = 'Inform_FZBoost'
name = 'FlexZBoostInformer'
config_options = CatInformer.config_options.copy()
config_options.update(zmin=SHARED_PARAMS,
zmax=SHARED_PARAMS,
Expand Down Expand Up @@ -190,10 +190,10 @@ def run(self):
self.add_data('model', self.model)


class FZBoost(CatEstimator):
"""FZBoost-based CatEstimator
class FlexZBoostEstimator(CatEstimator):
"""FlexZBoost-based CatEstimator
"""
name = 'FZBoost'
name = 'FlexZBoostEstimator'
config_options = CatEstimator.config_options.copy()
config_options.update(nzbins=SHARED_PARAMS,
nondetect_val=SHARED_PARAMS,
Expand Down
24 changes: 12 additions & 12 deletions tests/test_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def test_flexzboost():
'model': 'model.tmp'}
# zb_expected = np.array([0.13, 0.13, 0.13, 0.12, 0.12, 0.13, 0.12, 0.13,
# 0.12, 0.12])
train_algo = flexzboost.Inform_FZBoost
pz_algo = flexzboost.FZBoost
train_algo = flexzboost.FlexZBoostInformer
pz_algo = flexzboost.FlexZBoostEstimator
results, rerun_results, rerun3_results = one_algo("FZBoost", train_algo, pz_algo, train_config_dict, estim_config_dict)
# assert np.isclose(results.ancil['zmode'], zb_expected).all()
assert np.isclose(results.ancil['zmode'], rerun_results.ancil['zmode']).all()
Expand All @@ -54,8 +54,8 @@ def test_flexzboost_with_interp():
'qp_representation': 'interp'}
# zb_expected = np.array([0.13, 0.13, 0.13, 0.12, 0.12, 0.13, 0.12, 0.13,
# 0.12, 0.12])
train_algo = flexzboost.Inform_FZBoost
pz_algo = flexzboost.FZBoost
train_algo = flexzboost.FlexZBoostInformer
pz_algo = flexzboost.FlexZBoostEstimator
results, rerun_results, rerun3_results = one_algo("FZBoost", train_algo, pz_algo, train_config_dict, estim_config_dict)
# assert np.isclose(results.ancil['zmode'], zb_expected).all()
assert np.isclose(results.ancil['zmode'], rerun_results.ancil['zmode']).all()
Expand All @@ -77,8 +77,8 @@ def test_flexzboost_with_qp_flexzboost():
'qp_representation': 'flexzboost'}
# zb_expected = np.array([0.13, 0.13, 0.13, 0.12, 0.12, 0.13, 0.12, 0.13,
# 0.12, 0.12])
train_algo = flexzboost.Inform_FZBoost
pz_algo = flexzboost.FZBoost
train_algo = flexzboost.FlexZBoostInformer
pz_algo = flexzboost.FlexZBoostEstimator
results, rerun_results, rerun3_results = one_algo("FZBoost", train_algo, pz_algo, train_config_dict, estim_config_dict)
# assert np.isclose(results.ancil['zmode'], zb_expected).all()
assert np.isclose(results.ancil['zmode'], rerun_results.ancil['zmode']).all()
Expand All @@ -101,18 +101,18 @@ def test_flexzboost_with_unknown_qp_representation():
'qp_representation': 'bogus'}
# zb_expected = np.array([0.13, 0.13, 0.13, 0.12, 0.12, 0.13, 0.12, 0.13,
# 0.12, 0.12])
train_algo = flexzboost.Inform_FZBoost
pz_algo = flexzboost.FZBoost
train_algo = flexzboost.FlexZBoostInformer
pz_algo = flexzboost.FlexZBoostEstimator
with pytest.raises(ValueError) as excinfo:
one_algo("FZBoost", train_algo, pz_algo, train_config_dict, estim_config_dict)
assert "Unknown qp_representation" in str(excinfo.value)

def test_catch_bad_bands():
params = dict(bands='u,g,r,i,z,y')
with pytest.raises(ValueError):
flexzboost.Inform_FZBoost.make_stage(hdf5_groupname='', **params)
flexzboost.FlexZBoostInformer.make_stage(hdf5_groupname='', **params)
with pytest.raises(ValueError):
flexzboost.FZBoost.make_stage(hdf5_groupname='', **params)
flexzboost.FlexZBoostEstimator.make_stage(hdf5_groupname='', **params)


def test_missing_groupname_keyword():
Expand All @@ -126,7 +126,7 @@ def test_missing_groupname_keyword():
'objective':
'reg:squarederror'}}
with pytest.raises(ValueError):
_ = flexzboost.FZBoost.make_stage(**config_dict)
_ = flexzboost.FlexZBoostEstimator.make_stage(**config_dict)


def test_wrong_modelfile_keyword():
Expand All @@ -143,5 +143,5 @@ def test_wrong_modelfile_keyword():
'reg:squarederror'},
'model': 'nonexist.pkl'}
with pytest.raises(FileNotFoundError):
pz_algo = flexzboost.FZBoost.make_stage(**config_dict)
pz_algo = flexzboost.FlexZBoostEstimator.make_stage(**config_dict)
assert pz_algo.model is None

0 comments on commit eca1dcb

Please sign in to comment.