diff --git a/cdt/causality/graph/SAM.py b/cdt/causality/graph/SAM.py index 39120b8..ee9c485 100644 --- a/cdt/causality/graph/SAM.py +++ b/cdt/causality/graph/SAM.py @@ -373,7 +373,9 @@ def run_SAM(in_data, skeleton=None, is_mixed=False, device="cpu", if not linear and functionalComplexity=="n_hidden_units": neuron_optimizer.step() - return output.div_(test).cpu().numpy() + test_additions = test * (data.shape[0] // batch_size) + + return output.div_(test_additions).cpu().numpy() # Evaluate total effect with final DAG diff --git a/tests/scripts/test_causality_graph.py b/tests/scripts/test_causality_graph.py index d92202d..8e397db 100644 --- a/tests/scripts/test_causality_graph.py +++ b/tests/scripts/test_causality_graph.py @@ -54,6 +54,13 @@ def test_SAM(): assert isinstance(m.predict(data_graph), nx.DiGraph) return 0 + +def test_SAM_batchsize(): + m = SAM(train_epochs=10, test_epochs=10, nh=10, dnh=10, nruns=1, njobs=1, batch_size=3, gpus=0) + assert nx.to_numpy_array(m.predict(data_graph)).max() <= 1 + return 0 + + def test_SAMv1(): m = SAMv1(train_epochs=10, test_epochs=10, nh=10, dnh=10, nruns=1, njobs=1) assert isinstance(m.predict(data_graph), nx.DiGraph) @@ -62,6 +69,7 @@ def test_SAMv1(): if __name__ == "__main__": test_SAM() + test_SAM_batchsize() test_SAMv1() # test_directed() # test_undirected()