Skip to content

Commit

Permalink
Add custom metric for testing how many times a metric is called (#766)
Browse files Browse the repository at this point in the history
  • Loading branch information
oliverholworthy authored Sep 26, 2022
1 parent 96f31b8 commit 8fa182f
Showing 1 changed file with 27 additions and 2 deletions.
29 changes: 27 additions & 2 deletions tests/unit/tf/models/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,29 @@ def on_batch_end(self, batch, logs=None):
self.epoch_logs[self.epoch].append(logs)


class UpdateCountMetric(tf.keras.metrics.Metric):
"""Metric that returns a value representing the number of times it has been updated."""

def __init__(self, name="update_count_metric", **kwargs):
super().__init__(name=name, **kwargs)
self._built = False

def update_state(self, y_true, y_pred, sample_weight=None):
if not self._built:
self.call_count = self.add_weight(
"call_count", shape=tf.TensorShape([1]), initializer="zeros"
)
self._built = True

self.call_count.assign(self.call_count + tf.constant([1.0]))

def result(self):
return self.call_count[0]

def reset_state(self):
self.call_count.assign(tf.constant([0.0]))


@pytest.mark.parametrize(
["num_rows", "batch_size", "train_metrics_steps", "expected_steps", "expected_metrics_steps"],
[
Expand All @@ -140,7 +163,7 @@ def test_train_metrics_steps(
model.compile(
run_eagerly=True,
optimizer="adam",
metrics=[tf.keras.metrics.AUC(from_logits=True, name="auc")],
metrics=[UpdateCountMetric()],
)
metrics_callback = MetricsLogger()
callbacks = [metrics_callback]
Expand All @@ -157,7 +180,9 @@ def test_train_metrics_steps(
assert len(epoch0_logs) == expected_steps

# number of times metrics computed (every train_metrics_steps batches)
assert len({metrics["auc"] for metrics in epoch0_logs}) == expected_metrics_steps
assert (
len({metrics["update_count_metric"] for metrics in epoch0_logs}) == expected_metrics_steps
)


@pytest.mark.parametrize("run_eagerly", [True, False])
Expand Down

0 comments on commit 8fa182f

Please sign in to comment.