From d4e065cc11184e365f466b90b88c1e00299b69e3 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <> Date: Mon, 6 Jan 2025 12:29:45 +0100 Subject: [PATCH] Add type hints/simplify kernel_theoretical_timing Adding type hints allowed to simplify `kernel_theoretical_timing`. --- ndsl/dsl/dace/utils.py | 48 ++++++++++++++++++++++-------------------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/ndsl/dsl/dace/utils.py b/ndsl/dsl/dace/utils.py index 7efd8f2..05fa575 100644 --- a/ndsl/dsl/dace/utils.py +++ b/ndsl/dsl/dace/utils.py @@ -18,7 +18,7 @@ class DaCeProgress: """Rough timer & log for major operations of DaCe build stack.""" - def __init__(self, config: DaceConfig, label: str): + def __init__(self, config: DaceConfig, label: str) -> None: self.prefix = DaCeProgress.default_prefix(config) self.label = label @@ -26,11 +26,11 @@ def __init__(self, config: DaceConfig, label: str): def default_prefix(cls, config: DaceConfig) -> str: return f"[{config.get_orchestrate()}]" - def __enter__(self): + def __enter__(self) -> None: ndsl_log.debug(f"{self.prefix} {self.label}...") self.start = time.time() - def __exit__(self, _type, _val, _traceback): + def __exit__(self, _type, _val, _traceback) -> None: elapsed = time.time() - self.start ndsl_log.debug(f"{self.prefix} {self.label}...{elapsed}s.") @@ -133,7 +133,7 @@ def memory_static_analysis( def report_memory_static_analysis( sdfg: dace.sdfg.SDFG, allocations: Dict[dace.StorageType, StorageReport], - detail_report=False, + detail_report: bool = False, ) -> str: """Create a human readable report form the memory analysis results""" report = f"{sdfg.name}:\n" @@ -168,7 +168,9 @@ def report_memory_static_analysis( return report -def memory_static_analysis_from_path(sdfg_path: str, detail_report=False) -> str: +def memory_static_analysis_from_path( + sdfg_path: str, detail_report: bool = False +) -> str: """Open a SDFG and report the memory analysis""" sdfg = dace.SDFG.from_file(sdfg_path) return report_memory_static_analysis( @@ -181,7 +183,7 @@ def memory_static_analysis_from_path(sdfg_path: str, detail_report=False) -> str # ---------------------------------------------------------- # Theoretical bandwidth from SDFG # ---------------------------------------------------------- -def copy_kernel(q_in: FloatField, q_out: FloatField): +def copy_kernel(q_in: FloatField, q_out: FloatField) -> None: with computation(PARALLEL), interval(...): q_in = q_out @@ -203,15 +205,15 @@ def __init__(self, size, backend) -> None: ) orchestrate(obj=self, config=dace_config) - def __call__(self, A, B, n: int): + def __call__(self, A, B, n: int) -> None: for i in dace.nounroll(range(n)): self.copy_stencil(A, B) def kernel_theoretical_timing( sdfg: dace.sdfg.SDFG, - hardware_bw_in_GB_s=None, - backend=None, + hardware_bw_in_GB_s: Optional[float] = None, + backend: Optional[str] = None, ) -> Dict[str, float]: """Compute a lower timing bound for kernels with the following hypothesis: @@ -221,7 +223,7 @@ def kernel_theoretical_timing( - Memory pressure is mostly in read/write from global memory, inner scalar & shared memory is not counted towards memory movement. """ - if not hardware_bw_in_GB_s: + if hardware_bw_in_GB_s is None: size = np.array(sdfg.arrays["__g_self__w"].shape) print( f"Calculating experimental hardware bandwidth on {size}" @@ -246,13 +248,19 @@ def kernel_theoretical_timing( bench(A, B, n) dt.append((time.time() - s) / n) memory_size_in_b = np.prod(size) * np.dtype(Float).itemsize * 8 - bandwidth_in_bytes_s = memory_size_in_b / np.median(dt) - print( - f"Hardware bandwidth computed: {bandwidth_in_bytes_s/(1024*1024*1024)} GB/s" - ) - else: - bandwidth_in_bytes_s = hardware_bw_in_GB_s * 1024 * 1024 * 1024 - print(f"Given hardware bandwidth: {bandwidth_in_bytes_s/(1024*1024*1024)} GB/s") + measured_bandwidth_in_bytes_s = memory_size_in_b / np.median(dt) + + bandwidth_in_bytes_s = ( + measured_bandwidth_in_bytes_s + if hardware_bw_in_GB_s is None + else hardware_bw_in_GB_s * 1024 * 1024 * 1024 + ) + label = ( + "Hardware bandwidth computed" + if hardware_bw_in_GB_s + else "Given hardware bandwidth" + ) + print(f"{label}: {bandwidth_in_bytes_s/(1024*1024*1024)} GB/s") allmaps = [ (me, state) @@ -305,12 +313,6 @@ def kernel_theoretical_timing( except TypeError: pass - # Bad expansion - if not isinstance(newresult_in_us, sympy.core.numbers.Float) and not isinstance( - newresult_in_us, float - ): - continue - result[node.label] = float(newresult_in_us) return result