From e97d04a406aa9d4e234a2ed3bd085368f87a24fa Mon Sep 17 00:00:00 2001 From: Simeon Ehrig Date: Tue, 13 Feb 2024 13:46:37 +0100 Subject: [PATCH] add filter rule which allows only gcc and clang as nvcc host compiler --- bashi/filter_compiler_name.py | 10 +++ bashi/utils.py | 50 +++++++++++--- tests/test_generate_combination_list.py | 7 +- tests/test_nvcc_filter.py | 92 ++++++++++++++++++++++++- 4 files changed, 147 insertions(+), 12 deletions(-) diff --git a/bashi/filter_compiler_name.py b/bashi/filter_compiler_name.py index 0aa70a0..d57f10f 100644 --- a/bashi/filter_compiler_name.py +++ b/bashi/filter_compiler_name.py @@ -58,4 +58,14 @@ def compiler_name_filter( reason(output, "nvcc is not allowed as host compiler") return False + # Rule: n2 + if ( + DEVICE_COMPILER in row + and row[DEVICE_COMPILER].name == NVCC + and HOST_COMPILER in row + and not (row[HOST_COMPILER].name == GCC or row[HOST_COMPILER].name == CLANG) + ): + reason(output, "only gcc and clang are allowed as nvcc host compiler") + return False + return True diff --git a/bashi/utils.py b/bashi/utils.py index 4faf81d..50f3793 100644 --- a/bashi/utils.py +++ b/bashi/utils.py @@ -1,21 +1,25 @@ """Different helper functions for bashi""" -from typing import Dict, List, IO, Union, Optional -from collections import OrderedDict import dataclasses import sys -from typeguard import typechecked +from collections import OrderedDict +from typing import IO, Dict, List, Optional, Union + import packaging.version +from typeguard import typechecked + from bashi.types import ( + CombinationList, + FilterFunction, Parameter, ParameterValue, - ParameterValueTuple, - ParameterValueSingle, - ParameterValuePair, ParameterValueMatrix, - CombinationList, - FilterFunction, + ParameterValuePair, + ParameterValueSingle, + ParameterValueTuple, ) +from bashi.versions import COMPILERS +from bashi.globals import * # pylint: disable=wildcard-import,unused-wildcard-import @dataclasses.dataclass @@ -126,7 +130,7 @@ def create_parameter_value_pair( # pylint: disable=too-many-arguments def get_expected_parameter_value_pairs( parameter_matrix: ParameterValueMatrix, ) -> List[ParameterValuePair]: - """Takes parameter-value-matrix an creates a list of all expected parameter-values-pairs. + """Takes parameter-value-matrix and creates a list of all expected parameter-values-pairs. The pair-wise generator guaranties, that each pair of two parameter-values exist in at least one combination if no filter rules exist. Therefore the generated the generated list can be used to verify the output of the pair-wise generator. @@ -310,3 +314,31 @@ def reason(output: Optional[IO[str]], msg: str): file=output, end="", ) + + +@typechecked +def get_expected_bashi_parameter_value_pairs( + parameter_matrix: ParameterValueMatrix, +) -> List[ParameterValuePair]: + """Takes parameter-value-matrix and creates a list of all expected parameter-values-pairs + allowed by the bashi library. First it generates a complete list of parameter-value-pairs and + then it removes all pairs that are not allowed by filter rules. + + Args: + parameter_matrix (ParameterValueMatrix): matrix of parameter values + + Returns: + List[ParameterValuePair]: list of all parameter-value-pairs supported by bashi + """ + param_val_pair_list = get_expected_parameter_value_pairs(parameter_matrix) + + for compiler_name in set(COMPILERS) - set([GCC, CLANG, NVCC]): + remove_parameter_value_pair( + to_remove=create_parameter_value_pair( + HOST_COMPILER, compiler_name, 0, DEVICE_COMPILER, NVCC, 0 + ), + parameter_value_pairs=param_val_pair_list, + all_versions=True, + ) + + return param_val_pair_list diff --git a/tests/test_generate_combination_list.py b/tests/test_generate_combination_list.py index 29b3b66..91ebf3c 100644 --- a/tests/test_generate_combination_list.py +++ b/tests/test_generate_combination_list.py @@ -9,6 +9,7 @@ from bashi.generator import generate_combination_list from bashi.utils import ( get_expected_parameter_value_pairs, + get_expected_bashi_parameter_value_pairs, check_parameter_value_pair_in_combination_list, remove_parameter_value_pair, create_parameter_value_pair, @@ -182,7 +183,7 @@ def custom_filter(row: ParameterValueTuple) -> bool: class TestGeneratorRealData(unittest.TestCase): def test_generator_without_custom_filter(self): param_val_matrix = get_parameter_value_matrix() - expected_param_val_pairs = get_expected_parameter_value_pairs(param_val_matrix) + expected_param_val_pairs = get_expected_bashi_parameter_value_pairs(param_val_matrix) comb_list = generate_combination_list(param_val_matrix) @@ -203,7 +204,9 @@ def custom_filter(row: ParameterValueTuple) -> bool: return True param_val_matrix = get_parameter_value_matrix() - reduced_expected_param_val_pairs = get_expected_parameter_value_pairs(param_val_matrix) + reduced_expected_param_val_pairs = get_expected_bashi_parameter_value_pairs( + param_val_matrix + ) self.assertTrue( remove_parameter_value_pair( diff --git a/tests/test_nvcc_filter.py b/tests/test_nvcc_filter.py index a1b0e6b..9bce897 100644 --- a/tests/test_nvcc_filter.py +++ b/tests/test_nvcc_filter.py @@ -8,7 +8,7 @@ from bashi.filter_compiler_name import compiler_name_filter_typechecked -class TestNvccHostCompilerFilter(unittest.TestCase): +class TestNoNvccHostCompiler(unittest.TestCase): def test_valid_combination_rule_n1(self): self.assertTrue( compiler_name_filter_typechecked( @@ -79,3 +79,93 @@ def test_reason_rule_n1(self): compiler_name_filter_typechecked(OD({HOST_COMPILER: ppv((NVCC, 10.2))}), reason_msg) ) self.assertEqual(reason_msg.getvalue(), "nvcc is not allowed as host compiler") + + +class TestSupportedNvccHostCompiler(unittest.TestCase): + def test_invalid_combination_rule_n2(self): + for compiler_name in [CLANG_CUDA, HIPCC, ICPX, NVCC]: + for compiler_version in ["0", "13", "32a2"]: + reason_msg = io.StringIO() + self.assertFalse( + compiler_name_filter_typechecked( + OD( + { + HOST_COMPILER: ppv((compiler_name, compiler_version)), + DEVICE_COMPILER: ppv((NVCC, "12.3")), + } + ), + reason_msg, + ) + ) + # NVCC is filtered by rule n1 + if compiler_name != NVCC: + self.assertEqual( + reason_msg.getvalue(), + "only gcc and clang are allowed as nvcc host compiler", + ) + + self.assertFalse( + compiler_name_filter_typechecked( + OD( + { + HOST_COMPILER: ppv((HIPCC, "5.3")), + DEVICE_COMPILER: ppv((NVCC, "12.3")), + CMAKE: ppv((CMAKE, "3.18")), + BOOST: ppv((BOOST, "1.81.0")), + } + ) + ) + ) + self.assertFalse( + compiler_name_filter_typechecked( + OD( + { + HOST_COMPILER: ppv((HIPCC, "5.3")), + ALPAKA_ACC_CPU_B_SEQ_T_SEQ_ENABLE: ppv( + (ALPAKA_ACC_CPU_B_TBB_T_SEQ_ENABLE, "1.0.0") + ), + DEVICE_COMPILER: ppv((NVCC, "12.3")), + } + ) + ) + ) + + def test_valid_combination_rule_n2(self): + for compiler_name in [GCC, CLANG]: + for compiler_version in ["0", "13", "7b2"]: + self.assertTrue( + compiler_name_filter_typechecked( + OD( + { + HOST_COMPILER: ppv((compiler_name, compiler_version)), + DEVICE_COMPILER: ppv((NVCC, "12.3")), + } + ) + ) + ) + + self.assertTrue( + compiler_name_filter_typechecked( + OD( + { + HOST_COMPILER: ppv((GCC, "13")), + DEVICE_COMPILER: ppv((NVCC, "11.5")), + BOOST: ppv((BOOST, "1.84.0")), + CMAKE: ppv((CMAKE, "3.23")), + } + ) + ) + ) + self.assertTrue( + compiler_name_filter_typechecked( + OD( + { + HOST_COMPILER: ppv((CLANG, "14")), + ALPAKA_ACC_CPU_B_SEQ_T_SEQ_ENABLE: ppv( + (ALPAKA_ACC_CPU_B_TBB_T_SEQ_ENABLE, "1.0.0") + ), + DEVICE_COMPILER: ppv((NVCC, "10.1")), + } + ) + ) + )