From 6127fe081beb6b171e251008e0a58b6f3e24091c Mon Sep 17 00:00:00 2001
From: Gabriel Moreira
Date: Tue, 14 Jun 2022 21:02:43 -0300
Subject: [PATCH] Creating test fixture to reset state of TF Keras graph before each test (#513)

Co-authored-by: Ben Frederickson
---
 tests/unit/tf/_conftest.py                 |   5 +
 .../tf/prediction_tasks/test_multi_task.py | 115 ++++++++----
 2 files changed, 54 insertions(+), 66 deletions(-)

diff --git a/tests/unit/tf/_conftest.py b/tests/unit/tf/_conftest.py
index 6f881c15f3..c9e38ef002 100644
--- a/tests/unit/tf/_conftest.py
+++ b/tests/unit/tf/_conftest.py
@@ -25,6 +25,11 @@
 ASSETS_DIR = pathlib.Path(__file__).parent.parent / "assets"
 
 
+@pytest.fixture(autouse=True)
+def tf_clear_session():
+    tf.keras.backend.clear_session()
+
+
 @pytest.fixture
 def tf_con_features():
     features = {}
diff --git a/tests/unit/tf/prediction_tasks/test_multi_task.py b/tests/unit/tf/prediction_tasks/test_multi_task.py
index e09a24a52b..29fd59467b 100644
--- a/tests/unit/tf/prediction_tasks/test_multi_task.py
+++ b/tests/unit/tf/prediction_tasks/test_multi_task.py
@@ -28,22 +28,19 @@ def test_model_with_multiple_tasks(music_streaming_data: Dataset, task_blocks):
     metrics = model.train_step(ml.sample_batch(music_streaming_data, batch_size=50))
 
     assert metrics["loss"] >= 0
-    assert len(metrics) == 9
-    # TODO: Investigate why metrics names change when multiple tests are run,
-    # but not when single tests run
-    # assert set(list(metrics.keys())) == set(
-    #     [
-    #         "loss",
-    #         "regularization_loss",
-    #         "click/binary_classification_task_loss",
-    #         "play_percentage/regression_task_loss",
-    #         "play_percentage/regression_task_root_mean_squared_error",
-    #         "click/binary_classification_task_precision",
-    #         "click/binary_classification_task_recall",
-    #         "click/binary_classification_task_binary_accuracy",
-    #         "click/binary_classification_task_auc",
-    #     ]
-    # )
+    assert set(list(metrics.keys())) == set(
+        [
+            "loss",
+            "regularization_loss",
+            "click/binary_classification_task_loss",
+            "play_percentage/regression_task_loss",
+            "play_percentage/regression_task_root_mean_squared_error",
+            "click/binary_classification_task_precision",
+            "click/binary_classification_task_recall",
+            "click/binary_classification_task_binary_accuracy",
+            "click/binary_classification_task_auc",
+        ]
+    )
 
     if task_blocks:
         assert model.prediction_tasks[0].task_block != model.prediction_tasks[1].task_block
@@ -58,31 +55,24 @@ def test_mmoe_head(music_streaming_data: Dataset):
     metrics = model.train_step(ml.sample_batch(music_streaming_data, batch_size=50))
 
     assert metrics["loss"] >= 0
-    assert len(metrics) == 14
-    # TODO: Investigate why metrics names change when multiple tests are run,
-    # but not when single tests run
-    # assert set(metrics.keys()) == set(
-    #     [
-    #         [
-    #             [
-    #                 "loss",
-    #                 "click/binary_classification_task_loss",
-    #                 "like/binary_classification_task_loss",
-    #                 "play_percentage/regression_task_loss",
-    #                 "click/binary_classification_task_precision",
-    #                 "click/binary_classification_task_recall",
-    #                 "click/binary_classification_task_binary_accuracy",
-    #                 "click/binary_classification_task_auc",
-    #                 "like/binary_classification_task_precision_1",
-    #                 "like/binary_classification_task_recall_1",
-    #                 "like/binary_classification_task_binary_accuracy",
-    #                 "like/binary_classification_task_auc_1",
-    #                 "play_percentage/regression_task_root_mean_squared_error",
-    #                 "regularization_loss",
-    #             ]
-    #         ]
-    #     ]
-    # )
+    assert set(metrics.keys()) == set(
+        [
+            "loss",
+            "click/binary_classification_task_loss",
+            "like/binary_classification_task_loss",
+            "play_percentage/regression_task_loss",
+            "click/binary_classification_task_precision",
+            "click/binary_classification_task_recall",
+            "click/binary_classification_task_binary_accuracy",
+            "click/binary_classification_task_auc",
+            "like/binary_classification_task_precision_1",
+            "like/binary_classification_task_recall_1",
+            "like/binary_classification_task_binary_accuracy",
+            "like/binary_classification_task_auc_1",
+            "play_percentage/regression_task_root_mean_squared_error",
+            "regularization_loss",
+        ]
+    )
 
 
 def test_ple_head(music_streaming_data: Dataset):
@@ -97,28 +87,21 @@
     metrics = model.train_step(ml.sample_batch(music_streaming_data, batch_size=50))
 
     assert metrics["loss"] >= 0
-    assert len(metrics) == 14
-    # TODO: Investigate why metrics names change when multiple tests are run,
-    # but not when single tests run
-    # assert set(metrics.keys()) == set(
-    #     [
-    #         [
-    #             [
-    #                 "loss",
-    #                 "click/binary_classification_task_loss",
-    #                 "like/binary_classification_task_loss",
-    #                 "play_percentage/regression_task_loss",
-    #                 "click/binary_classification_task_precision",
-    #                 "click/binary_classification_task_recall",
-    #                 "click/binary_classification_task_binary_accuracy",
-    #                 "click/binary_classification_task_auc",
-    #                 "like/binary_classification_task_precision_1",
-    #                 "like/binary_classification_task_recall_1",
-    #                 "like/binary_classification_task_binary_accuracy",
-    #                 "like/binary_classification_task_auc_1",
-    #                 "play_percentage/regression_task_root_mean_squared_error",
-    #                 "regularization_loss",
-    #             ]
-    #         ]
-    #     ]
-    # )
+    assert set(metrics.keys()) == set(
+        [
+            "loss",
+            "click/binary_classification_task_loss",
+            "like/binary_classification_task_loss",
+            "play_percentage/regression_task_loss",
+            "click/binary_classification_task_precision",
+            "click/binary_classification_task_recall",
+            "click/binary_classification_task_binary_accuracy",
+            "click/binary_classification_task_auc",
+            "like/binary_classification_task_precision_1",
+            "like/binary_classification_task_recall_1",
+            "like/binary_classification_task_binary_accuracy",
+            "like/binary_classification_task_auc_1",
+            "play_percentage/regression_task_root_mean_squared_error",
+            "regularization_loss",
+        ]
+    )