Skip to content

Commit

Permalink
improve bashi-validate
Browse files Browse the repository at this point in the history
- add alternative short names for parameter
- add parameters in the order, which are passed as bashi-validate arguments
- extend print_nice(), that it prints a row in a shape, which can directly passed as parameters to bashi-validate
  • Loading branch information
SimeonEhrig committed Mar 18, 2024
1 parent 1ed3afa commit f614cd2
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 51 deletions.
2 changes: 1 addition & 1 deletion bashi/filter_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def compiler_filter(
bool: True, if parameter-value-tuple is valid.
"""
# uncomment me for debugging
# print_row_nice(row)
# print_row_nice(row, bashi_validate=False)

# Rule: c1
# NVCC as HOST_COMPILER is not allow
Expand Down
42 changes: 26 additions & 16 deletions bashi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,21 @@
)
from bashi.globals import * # pylint: disable=wildcard-import,unused-wildcard-import

# short names for parameter
PARAMETER_SHORT_NAME: dict[Parameter, str] = {
HOST_COMPILER: "host",
DEVICE_COMPILER: "device",
ALPAKA_ACC_CPU_B_OMP2_T_SEQ_ENABLE: "bOpenMP2thread",
ALPAKA_ACC_CPU_B_SEQ_T_OMP2_ENABLE: "bOpenMP2block",
ALPAKA_ACC_CPU_B_SEQ_T_SEQ_ENABLE: "bSeq",
ALPAKA_ACC_CPU_B_SEQ_T_THREADS_ENABLE: "bThreads",
ALPAKA_ACC_CPU_B_TBB_T_SEQ_ENABLE: "bTBB",
ALPAKA_ACC_GPU_CUDA_ENABLE: "bCUDA",
ALPAKA_ACC_GPU_HIP_ENABLE: "bHIP",
ALPAKA_ACC_SYCL_ENABLE: "bSYCL",
CXX_STANDARD: "c++",
}


@dataclasses.dataclass
class FilterAdapter:
Expand Down Expand Up @@ -371,40 +386,35 @@ def reason(output: Optional[IO[str]], msg: str):


# do not cover code, because the function is only used for debugging
def print_row_nice(row: ParameterValueTuple, init: str = ""): # pragma: no cover
def print_row_nice(
row: ParameterValueTuple, init: str = "", bashi_validate: bool = False
): # pragma: no cover
"""Prints a parameter-value-tuple in a short and nice way.
Args:
row (ParameterValueTuple): row with parameter-value-tuple
init (str, optional): Prefix of the output string. Defaults to "".
bashi_validate (bool): If it is set to True, the row is printed in a form that can be passed
directly as arguments to bashi-validate. Defaults to False.
"""
s = init
short_name: dict[str, str] = {
HOST_COMPILER: "host",
DEVICE_COMPILER: "device",
ALPAKA_ACC_CPU_B_OMP2_T_SEQ_ENABLE: "bOpenMP2thread",
ALPAKA_ACC_CPU_B_SEQ_T_OMP2_ENABLE: "bOpenMP2block",
ALPAKA_ACC_CPU_B_SEQ_T_SEQ_ENABLE: "bSeq",
ALPAKA_ACC_CPU_B_SEQ_T_THREADS_ENABLE: "bThreads",
ALPAKA_ACC_CPU_B_TBB_T_SEQ_ENABLE: "bTBB",
ALPAKA_ACC_GPU_CUDA_ENABLE: "bCUDA",
ALPAKA_ACC_GPU_HIP_ENABLE: "bHIP",
ALPAKA_ACC_SYCL_ENABLE: "bSYCL",
CXX_STANDARD: "c++",
}

nice_version: dict[packaging.version.Version, str] = {
ON_VER: "ON",
OFF_VER: "OFF",
}

for param, val in row.items():
parameter_prefix = "" if not bashi_validate else "--"
if param in [HOST_COMPILER, DEVICE_COMPILER]:
s += (
f"{short_name.get(param, param)}={short_name.get(val.name, val.name)}-"
f"{parameter_prefix}{PARAMETER_SHORT_NAME.get(param, param)}="
f"{PARAMETER_SHORT_NAME.get(val.name, val.name)}@"
f"{nice_version.get(val.version, str(val.version))} "
)
else:
s += (
f"{short_name.get(param, param)}={nice_version.get(val.version, str(val.version))} "
f"{parameter_prefix}{PARAMETER_SHORT_NAME.get(param, param)}="
f"{nice_version.get(val.version, str(val.version))} "
)
print(s)
117 changes: 83 additions & 34 deletions bashi/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import argparse
from argparse import ArgumentParser, Namespace

from typing import Sequence, Any, Callable, Optional, IO
from typing import Sequence, Any, Callable, Optional, IO, Dict, NamedTuple
from collections import OrderedDict
import io
import sys
Expand All @@ -15,10 +15,15 @@
from bashi.globals import * # pylint: disable=wildcard-import,unused-wildcard-import
from bashi.types import ParameterValue, ParameterValueTuple
from bashi.versions import is_supported_version
from bashi.utils import PARAMETER_SHORT_NAME
import bashi.filter_compiler
import bashi.filter_backend
import bashi.filter_software_dependency

ArgumentAlias = NamedTuple("ArgumentAlias", [("alias", List[str]), ("parameter", Parameter)])
# stores the ordering of the parameter arguments
param_order: List[str] = []


@typechecked
def cs(text: str, color: str) -> str:
Expand Down Expand Up @@ -88,6 +93,8 @@ def __call__(
try:
parsed_version = pkv.parse(version)
setattr(namespace, self.dest, parsed_version)
if option_string:
param_order.append(option_string)
except packaging.version.InvalidVersion:
exit_error(f"Could not parse version of argument {option_string}: {version}")

Expand Down Expand Up @@ -124,27 +131,73 @@ def __call__(
try:
parsed_version = pkv.parse(version)
setattr(namespace, self.dest, ParameterValue(name, parsed_version))
if option_string:
param_order.append(option_string)
except packaging.version.InvalidVersion:
exit_error(f"Could not parse version number of {name}: {version}")


def get_args() -> Namespace:
def get_args(args_alias: Dict[str, ArgumentAlias]) -> Namespace:
"""Set up command line arguments and parsed it.
Returns:
Namespace: The parsed command line arguments
"""
parser = argparse.ArgumentParser(description="Check if combination of parameters is valid.")

def add_param_alias(argument: str, args_alias: Dict[str, ArgumentAlias]) -> List[str]:
"""Returns the argument name and also an alias, if it is defined in the PARAMETER_SHORT_NAME
Args:
argument (str): Name of the argument without '--' prefix
args_alias (Dict[str, ArgumentAlias]): Stores the alias and it's parameter for an
argument
Raises:
ValueError: If parameter is unknown
Returns:
List[str]: List of arguments for argparse
"""
argument_alias = [f"--{argument}"]
modified_arg = argument
if argument == "host-compiler":
modified_arg = HOST_COMPILER

if argument == "device-compiler":
modified_arg = DEVICE_COMPILER

if argument == "cxx":
modified_arg = CXX_STANDARD

if not modified_arg in (
HOST_COMPILER,
DEVICE_COMPILER,
*BACKENDS,
UBUNTU,
CMAKE,
BOOST,
CXX_STANDARD,
):
raise ValueError(f"{modified_arg} is not a know Parameter")

if modified_arg in PARAMETER_SHORT_NAME:
argument_alias.append(f"--{PARAMETER_SHORT_NAME[modified_arg]}")

# argparse also replace the '-' with the '_' if it stores the argument
args_alias[argument.replace("-", "_")] = ArgumentAlias(argument_alias, modified_arg)

return argument_alias

parser.add_argument(
"--host-compiler",
*add_param_alias("host-compiler", args_alias),
type=str,
action=CompilerVersionCheck,
help="Define host compiler. Shape needs to be name@version. " "For example gcc@10",
)

parser.add_argument(
"--device-compiler",
*add_param_alias("device-compiler", args_alias),
type=str,
action=CompilerVersionCheck,
help="Define device compiler. Shape needs to be name@version. " "For example [email protected]",
Expand All @@ -153,27 +206,32 @@ def get_args() -> Namespace:
for backend in BACKENDS:
if backend != ALPAKA_ACC_GPU_CUDA_ENABLE:
parser.add_argument(
"--" + backend,
*add_param_alias(backend, args_alias),
type=str,
action=VersionCheck,
choices=["ON", "OFF"],
help=f"Set backend {backend} as enabled or disabled.",
)
else:
parser.add_argument(
"--" + backend,
*add_param_alias(backend, args_alias),
type=str,
action=VersionCheck,
help=f"Set backend {backend} to disabled (OFF) or a specific CUDA SDK version.",
)

parser.add_argument("--ubuntu", type=str, action=VersionCheck, help="Ubuntu version.")

parser.add_argument("--cmake", type=str, action=VersionCheck, help="Set CMake version.")

parser.add_argument("--boost", type=str, action=VersionCheck, help="Set Boost version.")

parser.add_argument("--cxx", type=str, action=VersionCheck, help="C++ version.")
for argument, help_text in (
("ubuntu", "Ubuntu version."),
("cmake", "Set CMake version."),
("boost", "Set Boost version."),
("cxx", "C++ version."),
):
parser.add_argument(
*add_param_alias(argument, args_alias),
type=str,
action=VersionCheck,
help=help_text,
)

return parser.parse_args()

Expand Down Expand Up @@ -249,29 +307,21 @@ def check_filter_chain(row: ParameterValueTuple) -> bool:

def main() -> None:
"""Entry point for the application."""
args = get_args()
# stores alias for parameter arguments and it parameter itself
args_alias: Dict[str, ArgumentAlias] = {}
args = get_args(args_alias)

row: ParameterValueTuple = OrderedDict()

for param, arg in [
(HOST_COMPILER, "host_compiler"),
(DEVICE_COMPILER, "device_compiler"),
]:
if hasattr(args, arg) and getattr(args, arg) is not None:
row[param] = getattr(args, arg)

for backend in BACKENDS:
if hasattr(args, backend) and getattr(args, backend) is not None:
row[backend] = ParameterValue(backend, getattr(args, backend))

for param, arg in [
(UBUNTU, "ubuntu"),
(CMAKE, "cmake"),
(BOOST, "boost"),
(CXX_STANDARD, "cxx"),
]:
if hasattr(args, arg) and getattr(args, arg) is not None:
row[param] = ParameterValue(param, getattr(args, arg))
# Add parameter-values in the order in which they are passed via arguments
for param_arg in param_order:
for arg, alias in args_alias.items():
if param_arg in alias.alias:
if getattr(args, arg) is not None:
if arg in ("host_compiler", "device_compiler"):
row[alias.parameter] = getattr(args, arg)
else:
row[alias.parameter] = ParameterValue(alias.parameter, getattr(args, arg))

for val_name, val_version in row.values():
if not is_supported_version(val_name, val_version):
Expand All @@ -281,7 +331,6 @@ def main() -> None:
"Yellow",
)
)

sys.exit(int(not check_filter_chain(row)))


Expand Down

0 comments on commit f614cd2

Please sign in to comment.