diff --git a/official/benchmark/tfhub_memory_usage_benchmark.py b/official/benchmark/tfhub_memory_usage_benchmark.py index 2c0a2fceef9..2dda28cfcf1 100644 --- a/official/benchmark/tfhub_memory_usage_benchmark.py +++ b/official/benchmark/tfhub_memory_usage_benchmark.py @@ -16,6 +16,7 @@ Loads a SavedModel and records memory usage. """ +import functools import time from absl import flags @@ -31,24 +32,31 @@ class TfHubMemoryUsageBenchmark(PerfZeroBenchmark): """A benchmark measuring memory usage for a given TF Hub SavedModel.""" def __init__(self, + hub_model_handle_list=None, output_dir=None, default_flags=None, root_data_dir=None, **kwargs): super(TfHubMemoryUsageBenchmark, self).__init__( output_dir=output_dir, default_flags=default_flags, **kwargs) - - def benchmark_memory_usage(self): + if hub_model_handle_list: + for hub_model_handle in hub_model_handle_list.split(';'): + setattr( + self, 'benchmark_' + hub_model_handle, + functools.partial(self.benchmark_memory_usage, hub_model_handle)) + + def benchmark_memory_usage( + self, hub_model_handle='https://tfhub.dev/google/nnlm-en-dim128/1'): start_time_sec = time.time() - self.load_model() + self.load_model(hub_model_handle) wall_time_sec = time.time() - start_time_sec metrics = [] self.report_benchmark(iters=-1, wall_time=wall_time_sec, metrics=metrics) - def load_model(self): + def load_model(self, hub_model_handle): """Loads a TF Hub module.""" - hub.load('https://tfhub.dev/google/nnlm-en-dim128/1') + hub.load(hub_model_handle) if __name__ == '__main__':