diff --git a/gematria/datasets/pipelines/BUILD.bazel b/gematria/datasets/pipelines/BUILD.bazel index c7fbcdd6..861738a2 100644 --- a/gematria/datasets/pipelines/BUILD.bazel +++ b/gematria/datasets/pipelines/BUILD.bazel @@ -54,8 +54,12 @@ gematria_py_binary( srcs = ["benchmark_bbs_lib.py"], deps = [ ":benchmark_cpu_scheduler", + "//gematria/datasets/python:bhive_importer", "//gematria/datasets/python:exegesis_benchmark", + "//gematria/llvm/python:canonicalizer", + "//gematria/llvm/python:llvm_architecture_support", "//gematria/proto:execution_annotation_py_pb2", + "//gematria/proto:throughput_py_pb2", ], ) @@ -82,6 +86,7 @@ gematria_py_test( ":benchmark_cpu_scheduler", "//gematria/io/python:tfrecord", "//gematria/proto:execution_annotation_py_pb2", + "//gematria/proto:throughput_py_pb2", ], ) diff --git a/gematria/datasets/pipelines/benchmark_bbs_lib.py b/gematria/datasets/pipelines/benchmark_bbs_lib.py index eb3300cb..19581766 100644 --- a/gematria/datasets/pipelines/benchmark_bbs_lib.py +++ b/gematria/datasets/pipelines/benchmark_bbs_lib.py @@ -21,6 +21,10 @@ from gematria.proto import execution_annotation_pb2 from gematria.datasets.python import exegesis_benchmark from gematria.datasets.pipelines import benchmark_cpu_scheduler +from gematria.proto import throughput_pb2 +from gematria.llvm.python import canonicalizer +from gematria.llvm.python import llvm_architecture_support +from gematria.datasets.python import bhive_importer _BEAM_METRIC_NAMESPACE_NAME = 'benchmark_bbs' @@ -71,14 +75,21 @@ def process( pass -class FormatBBsForOutput(beam.DoFn): - """A Beam function for formatting hex/throughput values for output.""" +class SerializeToProto(beam.DoFn): + """A Beam function for formatting hex/throughput values to protos.""" + + def setup(self): + self._x86_llvm = llvm_architecture_support.LlvmArchitectureSupport.x86_64() + self._x86_canonicalizer = canonicalizer.Canonicalizer.x86_64(self._x86_llvm) + self._importer = bhive_importer.BHiveImporter(self._x86_canonicalizer) def process( self, block_hex_and_throughput: tuple[str, float] - ) -> Iterable[str]: + ) -> Iterable[throughput_pb2.BasicBlockWithThroughputProto]: block_hex, throughput = block_hex_and_throughput - yield f'{block_hex},{throughput}' + yield self._importer.block_with_throughput_from_hex_and_throughput( + 'pipeline', block_hex, throughput + ) def benchmark_bbs( @@ -99,12 +110,15 @@ def pipeline(root: beam.Pipeline) -> None: benchmarked_blocks = annotated_bbs_shuffled | 'Benchmarking' >> beam.ParDo( BenchmarkBasicBlock(benchmark_scheduler_type) ) - formatted_output = benchmarked_blocks | 'Formatting' >> beam.ParDo( - FormatBBsForOutput() + block_protos = benchmarked_blocks | 'Serialize to protos' >> beam.ParDo( + SerializeToProto() ) - _ = formatted_output | 'Write To Text' >> beam.io.WriteToText( - output_file_pattern + _ = block_protos | 'Write serialized blocks' >> beam.io.WriteToTFRecord( + output_file_pattern, + coder=beam.coders.ProtoCoder( + throughput_pb2.BasicBlockWithThroughputProto().__class__ + ), ) return pipeline diff --git a/gematria/datasets/pipelines/benchmark_bbs_lib_test.py b/gematria/datasets/pipelines/benchmark_bbs_lib_test.py index 2b869c1d..7c4abe8b 100644 --- a/gematria/datasets/pipelines/benchmark_bbs_lib_test.py +++ b/gematria/datasets/pipelines/benchmark_bbs_lib_test.py @@ -16,12 +16,12 @@ from absl.testing import absltest from apache_beam.testing import test_pipeline -from apache_beam.testing import util as beam_test from gematria.datasets.pipelines import benchmark_bbs_lib from gematria.proto import execution_annotation_pb2 from gematria.io.python import tfrecord from gematria.datasets.pipelines import benchmark_cpu_scheduler +from gematria.proto import throughput_pb2 BLOCK_FOR_TESTING = execution_annotation_pb2.BlockWithExecutionAnnotations( execution_annotations=execution_annotation_pb2.ExecutionAnnotations( @@ -60,14 +60,14 @@ def test_benchmark_basic_block(self): self.assertEqual(block_hex, '3b31') self.assertLess(block_throughput, 10) - def test_format_bbs(self): - format_transform = benchmark_bbs_lib.FormatBBsForOutput() + def test_serialize_bbs_to_protos(self): + serialize_transform = benchmark_bbs_lib.SerializeToProto() + serialize_transform.setup() benchmarked_block_data = ('3b31', 5) - output = list(format_transform.process(benchmarked_block_data)) + output = list(serialize_transform.process(benchmarked_block_data)) self.assertLen(output, 1) - self.assertEqual(output[0], '3b31,5') def test_benchmark_bbs(self): test_tfrecord = self.create_tempfile() @@ -85,13 +85,19 @@ def test_benchmark_bbs(self): with test_pipeline.TestPipeline() as pipeline_under_test: pipeline_constructor(pipeline_under_test) - with open(output_file_pattern + '-00000-of-00001') as output_txt_file: - output_lines = output_txt_file.readlines() - self.assertLen(output_lines, 1) - - line_parts = output_lines[0].split(',') - self.assertEqual(line_parts[0], '3b31') - self.assertLess(float(line_parts[1]), 10) + throughputs = [] + for block_with_throughput in tfrecord.read_protos( + [output_file_pattern + '-00000-of-00001'], + throughput_pb2.BasicBlockWithThroughputProto, + ): + throughputs.append( + block_with_throughput.inverse_throughputs[ + 0 + ].inverse_throughput_cycles[0] + ) + + self.assertLen(throughputs, 1) + self.assertLess(throughputs[0], 10) if __name__ == '__main__':