Skip to content

Commit

Permalink
make sure units are int
Browse files Browse the repository at this point in the history
  • Loading branch information
scarlehoff committed Mar 4, 2024
1 parent 6a6688d commit 9cfdb2f
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 38 deletions.
4 changes: 2 additions & 2 deletions n3fit/src/n3fit/model_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ def layer_generator(i_layer, nodes_out, activation):
layer = base_layer_selector(
layer_type,
kernel_initializer=initializers,
units=nodes_out,
units=int(nodes_out),
activation=activation,
input_shape=(nodes_in,),
basis_size=basis_size,
Expand All @@ -755,7 +755,7 @@ def layer_generator(i_layer, nodes_out, activation):
layer_type,
replica_seeds=replica_seeds,
kernel_initializer=MetaLayer.select_initializer(initializer_name, seed=i_layer),
units=nodes_out,
units=int(nodes_out),
activation=activation,
is_first_layer=(i_layer == 0),
regularizer=reg,
Expand Down
75 changes: 39 additions & 36 deletions n3fit/src/n3fit/tests/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,42 +94,45 @@ def check_fit_results(
equal_checks = ["stop_epoch", "pos_state"]
approx_checks = ["erf_tr", "erf_vl", "chi2", "best_epoch", "best_epoch"]
relaxed_checks = ["arc_lengths", "integrability"]
for key, value in new_json.items():
reference = old_json[key]
err_msg = f"error for .json: {key}"
if key in equal_checks:
assert_equal(value, reference, err_msg=err_msg)
elif key in approx_checks:
assert_allclose(value, reference, err_msg=err_msg, rtol=rel_error)
elif key in relaxed_checks:
assert_allclose(value, reference, err_msg=err_msg, rtol=rel_error * 10)
elif key == "preprocessing":
for ref, cur in zip(reference, value):
err_msg += f" - {ref['fl']}"
assert_allclose(ref["smallx"], cur["smallx"], err_msg=err_msg, rtol=rel_error)
assert_allclose(ref["largex"], cur["largex"], err_msg=err_msg, rtol=rel_error)

# check that the times didnt grow in a weird manner
if timing:
# Better to catch up errors even when they happen to grow larger by chance
times = new_json["timing"]
fitting_time = times["walltime"]["replica_set_to_replica_fitted"]
assert fitting_time < EXPECTED_MAX_FITTIME

# For safety, check also the version
assert new_json["version"]["nnpdf"] == n3fit.__version__

new_expgrid = _load_exportgrid(new_expgrid_file)
old_expgrid = _load_exportgrid(old_expgrid_file)

# Now compare the exportgrids
for key, value in new_expgrid.items():
reference = old_expgrid[key]
err_msg = f"error for .exportgrid: {key}"
if key == "pdfgrid":
assert_allclose(value, reference, rtol=rel_error, atol=1e-6, err_msg=err_msg)
else:
assert_equal(value, reference, err_msg=err_msg)
try:
for key, value in new_json.items():
reference = old_json[key]
err_msg = f"error for .json: {key}"
if key in equal_checks:
assert_equal(value, reference, err_msg=err_msg)
elif key in approx_checks:
assert_allclose(value, reference, err_msg=err_msg, rtol=rel_error)
elif key in relaxed_checks:
assert_allclose(value, reference, err_msg=err_msg, rtol=rel_error * 10)
elif key == "preprocessing":
for ref, cur in zip(reference, value):
err_msg += f" - {ref['fl']}"
assert_allclose(ref["smallx"], cur["smallx"], err_msg=err_msg, rtol=rel_error)
assert_allclose(ref["largex"], cur["largex"], err_msg=err_msg, rtol=rel_error)

# check that the times didnt grow in a weird manner
if timing:
# Better to catch up errors even when they happen to grow larger by chance
times = new_json["timing"]
fitting_time = times["walltime"]["replica_set_to_replica_fitted"]
assert fitting_time < EXPECTED_MAX_FITTIME

# For safety, check also the version
assert new_json["version"]["nnpdf"] == n3fit.__version__

new_expgrid = _load_exportgrid(new_expgrid_file)
old_expgrid = _load_exportgrid(old_expgrid_file)

# Now compare the exportgrids
for key, value in new_expgrid.items():
reference = old_expgrid[key]
err_msg = f"error for .exportgrid: {key}"
if key == "pdfgrid":
assert_allclose(value, reference, rtol=rel_error, atol=1e-6, err_msg=err_msg)
else:
assert_equal(value, reference, err_msg=err_msg)
except:
import ipdb; ipdb.set_trace()


def _auxiliary_performfit(tmp_path, runcard=QUICKNAME, replica=1, timing=True, rel_error=2e-3):
Expand Down
1 change: 1 addition & 0 deletions n3fit/src/n3fit/tests/test_hyperopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def test_restart_from_pickle(tmp_path):

for i in range(n_trials_total):
# check that the files share exactly the same hyperopt history
import ipdb; ipdb.set_trace()
assert restart_json[i]['misc'] == direct_json[i]['misc']
assert restart_json[i]['state'] == direct_json[i]['state']
assert restart_json[i]['tid'] == direct_json[i]['tid']
Expand Down

0 comments on commit 9cfdb2f

Please sign in to comment.