Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 298920502
  • Loading branch information
tensorflower-gardener committed Mar 4, 2020
1 parent 0066ae2 commit bb8a18c
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions official/benchmark/tfhub_memory_usage_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Loads a SavedModel and records memory usage.
"""
import functools
import time

from absl import flags
Expand All @@ -31,24 +32,31 @@ class TfHubMemoryUsageBenchmark(PerfZeroBenchmark):
"""A benchmark measuring memory usage for a given TF Hub SavedModel."""

def __init__(self,
hub_model_handle_list=None,
output_dir=None,
default_flags=None,
root_data_dir=None,
**kwargs):
super(TfHubMemoryUsageBenchmark, self).__init__(
output_dir=output_dir, default_flags=default_flags, **kwargs)

def benchmark_memory_usage(self):
if hub_model_handle_list:
for hub_model_handle in hub_model_handle_list.split(';'):
setattr(
self, 'benchmark_' + hub_model_handle,
functools.partial(self.benchmark_memory_usage, hub_model_handle))

def benchmark_memory_usage(
self, hub_model_handle='https://tfhub.dev/google/nnlm-en-dim128/1'):
start_time_sec = time.time()
self.load_model()
self.load_model(hub_model_handle)
wall_time_sec = time.time() - start_time_sec

metrics = []
self.report_benchmark(iters=-1, wall_time=wall_time_sec, metrics=metrics)

def load_model(self):
def load_model(self, hub_model_handle):
"""Loads a TF Hub module."""
hub.load('https://tfhub.dev/google/nnlm-en-dim128/1')
hub.load(hub_model_handle)


if __name__ == '__main__':
Expand Down

0 comments on commit bb8a18c

Please sign in to comment.