Skip to content

Commit

Permalink
Merge pull request #19 from SimeonEhrig/unifyCompilerFilter
Browse files Browse the repository at this point in the history
Unify compiler name and version filter
  • Loading branch information
SimeonEhrig authored Mar 13, 2024
2 parents 2f6b4dd + 75e45e9 commit 1d5c80b
Show file tree
Hide file tree
Showing 12 changed files with 185 additions and 252 deletions.
6 changes: 2 additions & 4 deletions bashi/filter_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
from typeguard import typechecked
from bashi.types import FilterFunction

from bashi.filter_compiler_name import compiler_name_filter
from bashi.filter_compiler_version import compiler_version_filter
from bashi.filter_compiler import compiler_filter
from bashi.filter_backend import backend_filter
from bashi.filter_software_dependency import software_dependency_filter

Expand All @@ -25,8 +24,7 @@ def get_default_filter_chain(
FilterFunction: The filter function chain, which can be directly used in bashi.FilterAdapter
"""
return (
lambda row: compiler_name_filter(row)
and compiler_version_filter(row)
lambda row: compiler_filter(row)
and backend_filter(row)
and software_dependency_filter(row)
and custom_filter_function(row)
Expand Down
78 changes: 52 additions & 26 deletions bashi/filter_compiler_version.py → bashi/filter_compiler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Filter rules basing on host and device compiler names and versions.
All rules implemented in this filter have an identifier that begins with "v" and follows a number.
Examples: v1, v42, v678 ...
All rules implemented in this filter have an identifier that begins with "c" and follows a number.
Examples: c1, c42, c678 ...
These identifiers are used in the test names, for example, to make it clear which test is testing
which rule.
Expand All @@ -15,6 +15,9 @@
from bashi.versions import NVCC_GCC_MAX_VERSION, NVCC_CLANG_MAX_VERSION
from bashi.utils import reason

# uncomment me for debugging
# from bashi.utils import print_row_nice


def get_required_parameters() -> List[Parameter]:
"""Return list of parameters which will be checked in the filter.
Expand All @@ -26,18 +29,19 @@ def get_required_parameters() -> List[Parameter]:


@typechecked
def compiler_version_filter_typechecked(
def compiler_filter_typechecked(
row: ParameterValueTuple,
output: Optional[IO[str]] = None,
) -> bool:
"""Type-checked version of compiler_version_filter(). Type checking has a big performance cost,
which is why the non type-checked version is used for the pairwise generator.
"""Type-checked version of compiler_filter(). Type checking has a big performance cost, which
is why the non type-checked version is used for the pairwise generator.
"""
return compiler_version_filter(row, output)
return compiler_filter(row, output)


# pylint: disable=too-many-branches
def compiler_version_filter(
# pylint: disable=too-many-return-statements
def compiler_filter(
row: ParameterValueTuple,
output: Optional[IO[str]] = None,
) -> bool:
Expand All @@ -52,22 +56,43 @@ def compiler_version_filter(
Returns:
bool: True, if parameter-value-tuple is valid.
"""
# uncomment me for debugging
# print_row_nice(row)

# Rule: v1
if (
DEVICE_COMPILER in row
and row[DEVICE_COMPILER].name != NVCC
and HOST_COMPILER in row
and row[HOST_COMPILER].version != row[DEVICE_COMPILER].version
):
reason(output, "host and device compiler version must be the same (except for nvcc)")
# Rule: c1
# NVCC as HOST_COMPILER is not allow
# this rule will be never used, because of an implementation detail of the covertable library
# it is not possible to add NVCC as HOST_COMPILER and filter out afterwards
# this rule is only used by bashi-verify
if HOST_COMPILER in row and row[HOST_COMPILER].name == NVCC:
reason(output, "nvcc is not allowed as host compiler")
return False

if HOST_COMPILER in row and DEVICE_COMPILER in row:
if NVCC in (row[HOST_COMPILER].name, row[DEVICE_COMPILER].name):
# Rule: c2
if row[HOST_COMPILER].name not in (GCC, CLANG):
reason(output, "only gcc and clang are allowed as nvcc host compiler")
return False
else:
# Rule: c3
if row[HOST_COMPILER].name != row[DEVICE_COMPILER].name:
reason(output, "host and device compiler name must be the same (except for nvcc)")
return False

# Rule: c4
if row[HOST_COMPILER].version != row[DEVICE_COMPILER].version:
reason(
output,
"host and device compiler version must be the same (except for nvcc)",
)
return False

# now idea, how remove nested blocks without hitting the performance
# pylint: disable=too-many-nested-blocks
if DEVICE_COMPILER in row and row[DEVICE_COMPILER].name == NVCC:
if HOST_COMPILER in row and row[HOST_COMPILER].name == GCC:
# Rule: v2
# Rule: c5
# remove all unsupported nvcc gcc version combinations
# define which is the latest supported gcc compiler for a nvcc version

Expand All @@ -87,7 +112,7 @@ def compiler_version_filter(
break

if HOST_COMPILER in row and row[HOST_COMPILER].name == CLANG:
# Rule: v4
# Rule: c7
if row[DEVICE_COMPILER].version >= pkv.parse("11.3") and row[
DEVICE_COMPILER
].version <= pkv.parse("11.5"):
Expand All @@ -97,7 +122,7 @@ def compiler_version_filter(
)
return False

# Rule: v3
# Rule: c6
# remove all unsupported nvcc clang version combinations
# define which is the latest supported clang compiler for a nvcc version

Expand All @@ -116,17 +141,18 @@ def compiler_version_filter(
return False
break

# Rule: v5
# Rule: c8
# clang-cuda 13 and older is not supported
# this rule will be never used, because of an implementation detail of the covertable library
# it is not possible to add the clang-cuda versions and filter it out afterwards
# this rule is only used by bashi-verify
if (
DEVICE_COMPILER in row
and row[DEVICE_COMPILER].name == CLANG_CUDA
and row[DEVICE_COMPILER].version < pkv.parse("14")
):
reason(output, "all clang versions older than 14 are disabled as CUDA Compiler")
return False
for compiler in (HOST_COMPILER, DEVICE_COMPILER):
if (
compiler in row
and row[compiler].name == CLANG_CUDA
and row[compiler].version < pkv.parse("14")
):
reason(output, "all clang versions older than 14 are disabled as CUDA Compiler")
return False

return True
81 changes: 0 additions & 81 deletions bashi/filter_compiler_name.py

This file was deleted.

38 changes: 2 additions & 36 deletions bashi/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from typing import Dict, List
from collections import OrderedDict
import copy
import packaging.version as pkv

from covertable import make # type: ignore

Expand Down Expand Up @@ -34,44 +32,12 @@ def generate_combination_list(
Returns:
CombinationList: combination-list
"""
# use local version to do not modify parameter_value_matrix
local_param_val_mat = copy.deepcopy(parameter_value_matrix)

filter_chain = get_default_filter_chain(custom_filter)

def host_compiler_filter(param_val: ParameterValue) -> bool:
# Rule: n1
# remove nvcc as host compiler
if param_val.name == NVCC:
return False
# Rule: v5
# remove clang-cuda older than 14
if param_val.name == CLANG_CUDA and param_val.version < pkv.parse("14"):
return False

return True

def device_compiler_filter(param_val: ParameterValue) -> bool:
# Rule: v5
# remove clang-cuda older than 14
if param_val.name == CLANG_CUDA and param_val.version < pkv.parse("14"):
return False

return True

pre_filters = {HOST_COMPILER: host_compiler_filter, DEVICE_COMPILER: device_compiler_filter}

# some filter rules requires that specific parameter-values are already removed from the
# parameter-value-matrix
# otherwise the covertable library throws an error
for param, filter_func in pre_filters.items():
if param in local_param_val_mat:
local_param_val_mat[param] = list(filter(filter_func, local_param_val_mat[param]))

comb_list: CombinationList = []

all_pairs: List[Dict[Parameter, ParameterValue]] = make(
factors=local_param_val_mat,
factors=parameter_value_matrix,
length=2,
pre_filter=filter_chain,
) # type: ignore
Expand All @@ -81,7 +47,7 @@ def device_compiler_filter(param_val: ParameterValue) -> bool:
tmp_comb: Combination = OrderedDict({})
# covertable does not keep the ordering of the parameters
# therefore we sort it
for param in local_param_val_mat.keys():
for param in parameter_value_matrix.keys():
tmp_comb[param] = all_pair[param]
comb_list.append(tmp_comb)

Expand Down
40 changes: 40 additions & 0 deletions bashi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,3 +368,43 @@ def reason(output: Optional[IO[str]], msg: str):
file=output,
end="",
)


# do not cover code, because the function is only used for debugging
def print_row_nice(row: ParameterValueTuple, init: str = ""): # pragma: no cover
"""Prints a parameter-value-tuple in a short and nice way.
Args:
row (ParameterValueTuple): row with parameter-value-tuple
init (str, optional): Prefix of the output string. Defaults to "".
"""
s = init
short_name: dict[str, str] = {
HOST_COMPILER: "host",
DEVICE_COMPILER: "device",
ALPAKA_ACC_CPU_B_OMP2_T_SEQ_ENABLE: "bOpenMP2thread",
ALPAKA_ACC_CPU_B_SEQ_T_OMP2_ENABLE: "bOpenMP2block",
ALPAKA_ACC_CPU_B_SEQ_T_SEQ_ENABLE: "bSeq",
ALPAKA_ACC_CPU_B_SEQ_T_THREADS_ENABLE: "bThreads",
ALPAKA_ACC_CPU_B_TBB_T_SEQ_ENABLE: "bTBB",
ALPAKA_ACC_GPU_CUDA_ENABLE: "bCUDA",
ALPAKA_ACC_GPU_HIP_ENABLE: "bHIP",
ALPAKA_ACC_SYCL_ENABLE: "bSYCL",
CXX_STANDARD: "c++",
}
nice_version: dict[packaging.version.Version, str] = {
ON_VER: "ON",
OFF_VER: "OFF",
}

for param, val in row.items():
if param in [HOST_COMPILER, DEVICE_COMPILER]:
s += (
f"{short_name.get(param, param)}={short_name.get(val.name, val.name)}-"
f"{nice_version.get(val.version, str(val.version))} "
)
else:
s += (
f"{short_name.get(param, param)}={nice_version.get(val.version, str(val.version))} "
)
print(s)
17 changes: 4 additions & 13 deletions bashi/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
from bashi.globals import * # pylint: disable=wildcard-import,unused-wildcard-import
from bashi.types import ParameterValue, ParameterValueTuple
from bashi.versions import is_supported_version
import bashi.filter_compiler_name
import bashi.filter_compiler_version
import bashi.filter_compiler
import bashi.filter_backend
import bashi.filter_software_dependency

Expand Down Expand Up @@ -244,17 +243,9 @@ def check_filter_chain(row: ParameterValueTuple) -> bool:
all_true = 0
all_true += int(
check_single_filter(
bashi.filter_compiler_name.compiler_name_filter_typechecked,
bashi.filter_compiler.compiler_filter,
row,
bashi.filter_compiler_name.get_required_parameters(),
)
)

all_true += int(
check_single_filter(
bashi.filter_compiler_version.compiler_version_filter_typechecked,
row,
bashi.filter_compiler_version.get_required_parameters(),
bashi.filter_compiler.get_required_parameters(),
)
)
all_true += int(
Expand All @@ -273,7 +264,7 @@ def check_filter_chain(row: ParameterValueTuple) -> bool:
)

# each filter add a one, if it was successful
return all_true == 4
return all_true == 3


def main() -> None:
Expand Down
Loading

0 comments on commit 1d5c80b

Please sign in to comment.