diff --git a/n3fit/src/n3fit/model_gen.py b/n3fit/src/n3fit/model_gen.py
index 5abe78f840..8c4c30c59a 100644
--- a/n3fit/src/n3fit/model_gen.py
+++ b/n3fit/src/n3fit/model_gen.py
@@ -737,7 +737,7 @@ def layer_generator(i_layer, nodes_out, activation):
         layer = base_layer_selector(
             layer_type,
             kernel_initializer=initializers,
-            units=nodes_out,
+            units=int(nodes_out),
             activation=activation,
             input_shape=(nodes_in,),
             basis_size=basis_size,
@@ -755,7 +755,7 @@ def layer_generator(i_layer, nodes_out, activation):
             layer_type,
             replica_seeds=replica_seeds,
             kernel_initializer=MetaLayer.select_initializer(initializer_name, seed=i_layer),
-            units=nodes_out,
+            units=int(nodes_out),
             activation=activation,
             is_first_layer=(i_layer == 0),
             regularizer=reg,