
Commit

Creating test fixture to reset state of TF Keras graph before each test (#513)

Co-authored-by: Ben Frederickson <[email protected]>
gabrielspmoreira and benfred authored Jun 15, 2022
1 parent d99def0 commit 6127fe0
Showing 2 changed files with 54 additions and 66 deletions.
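
Background on the fix (context, not part of the diff): Keras names layers and metrics with per-session counters, so a metric instantiated after earlier tests have already built models can come out with a numeric suffix (e.g. "precision_1") that it would not have when the test runs in isolation, which is what the TODO notes removed below describe. Clearing the session before each test resets those counters. A minimal sketch of the behaviour, assuming TensorFlow 2.x:

import tensorflow as tf

# Within one session, repeated instances of the same metric class get suffixed names.
print(tf.keras.metrics.Precision().name)  # "precision"
print(tf.keras.metrics.Precision().name)  # "precision_1"

# clear_session() resets the name counters, so the next metric starts unsuffixed again.
tf.keras.backend.clear_session()
print(tf.keras.metrics.Precision().name)  # "precision"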
5 changes: 5 additions & 0 deletions tests/unit/tf/_conftest.py
@@ -25,6 +25,11 @@
 ASSETS_DIR = pathlib.Path(__file__).parent.parent / "assets"


+@pytest.fixture(autouse=True)
+def tf_clear_session():
+    tf.keras.backend.clear_session()
+
+
 @pytest.fixture
 def tf_con_features():
     features = {}
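A note on scope (hedged, since _conftest.py is presumably wired into the suite's conftest setup): once the fixture above is active with autouse=True, pytest runs it before every collected test without the tests requesting it explicitly. A rough sketch of the resulting guarantee, with a hypothetical test name for illustration:

import pytest
import tensorflow as tf

@pytest.fixture(autouse=True)
def tf_clear_session():
    tf.keras.backend.clear_session()

# Hypothetical test: by the time it runs, the autouse fixture has already
# cleared any metric/layer name counters left over from earlier tests.
def test_first_metric_name_is_unsuffixed():
    assert tf.keras.metrics.Precision().name == "precision"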
115 changes: 49 additions & 66 deletions tests/unit/tf/prediction_tasks/test_multi_task.py
@@ -28,22 +28,19 @@ def test_model_with_multiple_tasks(music_streaming_data: Dataset, task_blocks):
     metrics = model.train_step(ml.sample_batch(music_streaming_data, batch_size=50))

     assert metrics["loss"] >= 0
-    assert len(metrics) == 9
-    # TODO: Investigate why metrics names change when multiple tests are run,
-    # but not when single tests run
-    # assert set(list(metrics.keys())) == set(
-    #     [
-    #         "loss",
-    #         "regularization_loss",
-    #         "click/binary_classification_task_loss",
-    #         "play_percentage/regression_task_loss",
-    #         "play_percentage/regression_task_root_mean_squared_error",
-    #         "click/binary_classification_task_precision",
-    #         "click/binary_classification_task_recall",
-    #         "click/binary_classification_task_binary_accuracy",
-    #         "click/binary_classification_task_auc",
-    #     ]
-    # )
+    assert set(list(metrics.keys())) == set(
+        [
+            "loss",
+            "regularization_loss",
+            "click/binary_classification_task_loss",
+            "play_percentage/regression_task_loss",
+            "play_percentage/regression_task_root_mean_squared_error",
+            "click/binary_classification_task_precision",
+            "click/binary_classification_task_recall",
+            "click/binary_classification_task_binary_accuracy",
+            "click/binary_classification_task_auc",
+        ]
+    )
     if task_blocks:
         assert model.prediction_tasks[0].task_block != model.prediction_tasks[1].task_block

@@ -58,31 +55,24 @@ def test_mmoe_head(music_streaming_data: Dataset):
     metrics = model.train_step(ml.sample_batch(music_streaming_data, batch_size=50))

     assert metrics["loss"] >= 0
-    assert len(metrics) == 14
-    # TODO: Investigate why metrics names change when multiple tests are run,
-    # but not when single tests run
-    # assert set(metrics.keys()) == set(
-    #     [
-    #         [
-    #             [
-    #                 "loss",
-    #                 "click/binary_classification_task_loss",
-    #                 "like/binary_classification_task_loss",
-    #                 "play_percentage/regression_task_loss",
-    #                 "click/binary_classification_task_precision",
-    #                 "click/binary_classification_task_recall",
-    #                 "click/binary_classification_task_binary_accuracy",
-    #                 "click/binary_classification_task_auc",
-    #                 "like/binary_classification_task_precision_1",
-    #                 "like/binary_classification_task_recall_1",
-    #                 "like/binary_classification_task_binary_accuracy",
-    #                 "like/binary_classification_task_auc_1",
-    #                 "play_percentage/regression_task_root_mean_squared_error",
-    #                 "regularization_loss",
-    #             ]
-    #         ]
-    #     ]
-    # )
+    assert set(metrics.keys()) == set(
+        [
+            "loss",
+            "click/binary_classification_task_loss",
+            "like/binary_classification_task_loss",
+            "play_percentage/regression_task_loss",
+            "click/binary_classification_task_precision",
+            "click/binary_classification_task_recall",
+            "click/binary_classification_task_binary_accuracy",
+            "click/binary_classification_task_auc",
+            "like/binary_classification_task_precision_1",
+            "like/binary_classification_task_recall_1",
+            "like/binary_classification_task_binary_accuracy",
+            "like/binary_classification_task_auc_1",
+            "play_percentage/regression_task_root_mean_squared_error",
+            "regularization_loss",
+        ]
+    )


 def test_ple_head(music_streaming_data: Dataset):
@@ -97,28 +87,21 @@ metrics = model.train_step(ml.sample_batch(music_streaming_data, batch_size=50)
     metrics = model.train_step(ml.sample_batch(music_streaming_data, batch_size=50))

     assert metrics["loss"] >= 0
-    assert len(metrics) == 14
-    # TODO: Investigate why metrics names change when multiple tests are run,
-    # but not when single tests run
-    # assert set(metrics.keys()) == set(
-    #     [
-    #         [
-    #             [
-    #                 "loss",
-    #                 "click/binary_classification_task_loss",
-    #                 "like/binary_classification_task_loss",
-    #                 "play_percentage/regression_task_loss",
-    #                 "click/binary_classification_task_precision",
-    #                 "click/binary_classification_task_recall",
-    #                 "click/binary_classification_task_binary_accuracy",
-    #                 "click/binary_classification_task_auc",
-    #                 "like/binary_classification_task_precision_1",
-    #                 "like/binary_classification_task_recall_1",
-    #                 "like/binary_classification_task_binary_accuracy",
-    #                 "like/binary_classification_task_auc_1",
-    #                 "play_percentage/regression_task_root_mean_squared_error",
-    #                 "regularization_loss",
-    #             ]
-    #         ]
-    #     ]
-    # )
+    assert set(metrics.keys()) == set(
+        [
+            "loss",
+            "click/binary_classification_task_loss",
+            "like/binary_classification_task_loss",
+            "play_percentage/regression_task_loss",
+            "click/binary_classification_task_precision",
+            "click/binary_classification_task_recall",
+            "click/binary_classification_task_binary_accuracy",
+            "click/binary_classification_task_auc",
+            "like/binary_classification_task_precision_1",
+            "like/binary_classification_task_recall_1",
+            "like/binary_classification_task_binary_accuracy",
+            "like/binary_classification_task_auc_1",
+            "play_percentage/regression_task_root_mean_squared_error",
+            "regularization_loss",
+        ]
+    )
