-
Notifications
You must be signed in to change notification settings - Fork 315
/
Copy pathtest_profiler.py
80 lines (66 loc) · 2.84 KB
/
test_profiler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""Profiler tests for TPUs."""
import glob
import json
import os
import unittest
from tensorboard_plugin_profile.convert import raw_to_tool_data
class ProfilerTest(unittest.TestCase):
"""Test for profile collected with JAX."""
def _get_session_snapshot(self):
"""Gets a session snapshot of current session. assume only one session."""
profile_plugin_root = "tensorboard/plugins/profile"
# The session exists under a director whose name is time-dependent.
profile_session_glob = os.path.join(profile_plugin_root, "*", "*.xplane.pb")
return glob.glob(profile_session_glob)
def test_xplane_is_present(self):
files = self._get_session_snapshot()
self.assertEqual(len(files), 1)
def test_overview_page(self):
xspace_filenames = self._get_session_snapshot()
result, _ = raw_to_tool_data.xspace_to_tool_data(xspace_filenames, "overview_page^", {})
result = json.loads(result)
run_environment = result[2]
self.assertEqual(run_environment["p"]["host_count"], "1")
self.assertRegex(run_environment["p"]["device_type"], "TPU.*")
def test_op_profile(self):
xspace_filenames = self._get_session_snapshot()
result, _ = raw_to_tool_data.xspace_to_tool_data(xspace_filenames, "op_profile^", {})
result = json.loads(result)
self.assertIn("byCategory", result)
self.assertIn("metrics", result["byCategory"])
overall_metrics = result["byCategory"]["metrics"]
self.assertIn("flops", overall_metrics)
self.assertIn("bandwidthUtils", overall_metrics)
self.assertGreater(overall_metrics["flops"], 0)
def test_device_trace_contains_threads(self):
xspace_filenames = self._get_session_snapshot()
result, _ = raw_to_tool_data.xspace_to_tool_data(xspace_filenames, "trace_viewer^", {})
result = json.loads(result)
thread_names = []
for event in result["traceEvents"]:
if "name" in event and event["name"] == "thread_name":
thread_names.append((event["args"]["name"]))
expected_threads = [
"Framework Name Scope",
"Framework Ops",
"XLA Modules",
"XLA Ops",
"XLA TraceMe",
"Steps",
]
# Ensure that thread_names contains at least all expected threads.
self.assertEqual(set(expected_threads) - set(thread_names), set())
if __name__ == "__main__":
unittest.main()