Skip to content

Commit

Permalink
Merge pull request #44 from LSSTDESC/tqz/utils_refactor
Browse files Browse the repository at this point in the history
Tqz/utils refactor
  • Loading branch information
ztq1996 authored May 8, 2024
2 parents 2d51caa + e95ae06 commit 7ee9e42
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 9 deletions.
4 changes: 2 additions & 2 deletions examples/BPZ_lite_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"import desc_bpz\n",
"from rail.core.data import TableHandle\n",
"from rail.core.stage import RailStage\n",
"from rail.core.utils import RAILDIR\n",
"from rail.utils.path_utils import RAILDIR\n",
"from rail.estimation.algos.bpz_lite import BPZliteInformer, BPZliteEstimator"
]
},
Expand Down Expand Up @@ -693,7 +693,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
"version": "3.10.0"
}
},
"nbformat": 4,
Expand Down
4 changes: 2 additions & 2 deletions examples/BPZ_lite_with_custom_SEDs.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
"metadata": {},
"outputs": [],
"source": [
"from rail.core.utils import RAILDIR\n",
"from rail.utils.path_utils import RAILDIR\n",
"import os\n",
"custom_data_path = RAILDIR + '/rail/examples_data/estimation_data/data'\n",
"sedpath = RAILDIR + '/rail/examples_data/estimation_data/data/SED'\n",
Expand Down Expand Up @@ -570,7 +570,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.0"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion src/rail/estimation/algos/bpz_lite.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import tables_io
from ceci.config import StageParameter as Param
from rail.estimation.estimator import CatEstimator, CatInformer
from rail.core.utils import RAILDIR
from rail.utils.path_utils import RAILDIR
from rail.bpz.utils import RAIL_BPZ_DIR
from rail.core.common_params import SHARED_PARAMS

Expand Down
8 changes: 4 additions & 4 deletions tests/test_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import tables_io
from rail.core.stage import RailStage
from rail.core.data import DataStore, TableHandle
from rail.core.utils import RAILDIR
from rail.core.algo_utils import one_algo
from rail.utils.path_utils import RAILDIR
from rail.utils.testing_utils import one_algo
from rail.estimation.algos import bpz_lite
from rail.bpz.utils import RAIL_BPZ_DIR

Expand Down Expand Up @@ -71,11 +71,11 @@ def test_bpz_lite():
'hdf5_groupname': 'photometry',
'nt_array': [8],
'model': 'testmodel_bpz.pkl'}
zb_expected = np.array([0.16, 0.12, 0.14, 0.14, 0.06, 0.14, 0.12, 0.14, 0.06, 0.16])
zb_expected = np.array([0.16, 0.12, 0.0, 0.12, 0.05, 0.14, 0.11, 0.14, 0.05, 0.16])
train_algo = None
pz_algo = bpz_lite.BPZliteEstimator
results, rerun_results, rerun3_results = one_algo("BPZ_lite", train_algo, pz_algo, train_config_dict, estim_config_dict)
assert np.isclose(results.ancil['zmode'], zb_expected).all()
assert np.isclose(results.ancil['zmode'], zb_expected, atol=0.03).all()
assert np.isclose(results.ancil['zmode'], rerun_results.ancil['zmode']).all()

@pytest.mark.parametrize(
Expand Down

0 comments on commit 7ee9e42

Please sign in to comment.