Skip to content

Commit

Permalink
Merge pull request #1374 from CEED/jrwrigh/fix_test_python
Browse files Browse the repository at this point in the history
fix(test): Use explicit `typing` objects for containers
  • Loading branch information
jrwrigh authored Oct 13, 2023
2 parents 19868e1 + 78cb100 commit a77d276
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 23 deletions.
2 changes: 1 addition & 1 deletion tests/junit.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def check_post_skip(self, test: str, spec: TestSpec, resource: str, stderr: str)
return f'SYCL device type not available'
return None

def check_required_failure(self, test: str, spec: TestSpec, resource: str, stderr: str) -> tuple[str, bool]:
def check_required_failure(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Tuple[str, bool]:
"""Check whether a test case is expected to fail and if it failed expectedly
Args:
Expand Down
44 changes: 22 additions & 22 deletions tests/junit_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from itertools import product
import sys
import time
from typing import Optional
from typing import Optional, Tuple, List

sys.path.insert(0, str(Path(__file__).parent / "junit-xml"))
from junit_xml import TestCase, TestSuite, to_xml_report_string # nopep8
Expand Down Expand Up @@ -133,7 +133,7 @@ def check_post_skip(self, test: str, spec: TestSpec, resource: str, stderr: str)
"""
return None

def check_required_failure(self, test: str, spec: TestSpec, resource: str, stderr: str) -> tuple[str, bool]:
def check_required_failure(self, test: str, spec: TestSpec, resource: str, stderr: str) -> Tuple[str, bool]:
"""Check whether a test case is expected to fail and if it failed expectedly
Args:
Expand Down Expand Up @@ -174,7 +174,7 @@ def has_cgnsdiff() -> bool:
return 'not found' not in proc.stderr.decode('utf-8')


def contains_any(base: str, substrings: list[str]) -> bool:
def contains_any(base: str, substrings: List[str]) -> bool:
"""Helper function, checks if any of the substrings are included in the base string
Args:
Expand All @@ -187,7 +187,7 @@ def contains_any(base: str, substrings: list[str]) -> bool:
return any((sub in base for sub in substrings))


def startswith_any(base: str, prefixes: list[str]) -> bool:
def startswith_any(base: str, prefixes: List[str]) -> bool:
"""Helper function, checks if the base string is prefixed by any of `prefixes`
Args:
def parse_test_line(line: str) -> TestSpec:
    """Parse a single TESTARGS line from a source file into a test specification.

    Handles both the bare form ``TESTARGS <args...>`` and the annotated form
    ``TESTARGS(name="myname",only="serial,int32") <args...>``.

    Args:
        line: Source line containing a TESTARGS directive

    Returns:
        TestSpec: Parsed specification of test case
    """
    # Split on whitespace, but keep double-quoted substrings intact as single tokens
    args: List[str] = re.findall("(?:\".*?\"|\\S)+", line.strip())
    if args[0] == 'TESTARGS':
        # Bare directive: everything after TESTARGS is the argument list
        return TestSpec(name='', args=args[1:])
    # Extract the text between 'TESTARGS(' and the closing ')'
    raw_test_args: str = args[0][args[0].index('TESTARGS(') + 9:args[0].rindex(')')]
    # transform 'name="myname",only="serial,int32"' into {'name': 'myname', 'only': 'serial,int32'}
    test_args: dict = dict([''.join(t).split('=') for t in re.findall(r"""([^,=]+)(=)"([^"]*)\"""", raw_test_args)])
    name: str = test_args.get('name', '')
    constraints: List[str] = test_args['only'].split(',') if 'only' in test_args else []
    if len(args) > 1:
        return TestSpec(name=name, only=constraints, args=args[1:])
    else:
        return TestSpec(name=name, only=constraints)


def get_test_args(source_file: Path) -> list[TestSpec]:
def get_test_args(source_file: Path) -> List[TestSpec]:
"""Parse all test cases from a given source file
Args:
Expand Down Expand Up @@ -264,18 +264,18 @@ def diff_csv(test_csv: Path, true_csv: Path, zero_tol: float = 3e-10, rel_tol: f
Returns:
str: Diff output between result and expected CSVs
"""
test_lines: list[str] = test_csv.read_text().splitlines()
true_lines: list[str] = true_csv.read_text().splitlines()
test_lines: List[str] = test_csv.read_text().splitlines()
true_lines: List[str] = true_csv.read_text().splitlines()

if test_lines[0] != true_lines[0]:
return ''.join(difflib.unified_diff([f'{test_lines[0]}\n'], [f'{true_lines[0]}\n'],
tofile='found CSV columns', fromfile='expected CSV columns'))

diff_lines: list[str] = list()
column_names: list[str] = true_lines[0].strip().split(',')
diff_lines: List[str] = list()
column_names: List[str] = true_lines[0].strip().split(',')
for test_line, true_line in zip(test_lines[1:], true_lines[1:]):
test_vals: list[float] = [float(val.strip()) for val in test_line.strip().split(',')]
true_vals: list[float] = [float(val.strip()) for val in true_line.strip().split(',')]
test_vals: List[float] = [float(val.strip()) for val in test_line.strip().split(',')]
true_vals: List[float] = [float(val.strip()) for val in true_line.strip().split(',')]
for test_val, true_val, column_name in zip(test_vals, true_vals, column_names):
true_zero: bool = abs(true_val) < zero_tol
test_zero: bool = abs(test_val) < zero_tol
Expand All @@ -302,7 +302,7 @@ def diff_cgns(test_cgns: Path, true_cgns: Path, tolerance: float = 1e-12) -> str
"""
my_env: dict = os.environ.copy()

run_args: list[str] = ['cgnsdiff', '-d', '-t', f'{tolerance}', str(test_cgns), str(true_cgns)]
run_args: List[str] = ['cgnsdiff', '-d', '-t', f'{tolerance}', str(test_cgns), str(true_cgns)]
proc = subprocess.run(' '.join(run_args),
shell=True,
stdout=subprocess.PIPE,
Expand All @@ -329,7 +329,7 @@ def run_test(index: int, test: str, spec: TestSpec, backend: str,
TestCase: Test case result
"""
source_path: Path = suite_spec.get_source_path(test)
run_args: list = [suite_spec.get_run_path(test), *spec.args]
run_args: List = [suite_spec.get_run_path(test), *spec.args]

if '{ceed_resource}' in run_args:
run_args[run_args.index('{ceed_resource}')] = backend
Expand Down Expand Up @@ -362,11 +362,11 @@ def run_test(index: int, test: str, spec: TestSpec, backend: str,
stdout=proc.stdout.decode('utf-8'),
stderr=proc.stderr.decode('utf-8'),
allow_multiple_subelements=True)
ref_csvs: list[Path] = []
output_files: list[str] = [arg for arg in spec.args if 'ascii:' in arg]
ref_csvs: List[Path] = []
output_files: List[str] = [arg for arg in spec.args if 'ascii:' in arg]
if output_files:
ref_csvs = [suite_spec.get_output_path(test, file.split('ascii:')[-1]) for file in output_files]
ref_cgns: list[Path] = []
ref_cgns: List[Path] = []
output_files = [arg for arg in spec.args if 'cgns:' in arg]
if output_files:
ref_cgns = [suite_spec.get_output_path(test, file.split('cgns:')[-1]) for file in output_files]
Expand Down Expand Up @@ -471,7 +471,7 @@ def init_process():
my_env['CEED_ERROR_HANDLER'] = 'exit'


def run_tests(test: str, ceed_backends: list[str], mode: RunMode, nproc: int,
def run_tests(test: str, ceed_backends: List[str], mode: RunMode, nproc: int,
suite_spec: SuiteSpec, pool_size: int = 1) -> TestSuite:
"""Run all test cases for `test` with each of the provided `ceed_backends`
Expand All @@ -486,16 +486,16 @@ def run_tests(test: str, ceed_backends: list[str], mode: RunMode, nproc: int,
Returns:
TestSuite: JUnit `TestSuite` containing results of all test cases
"""
test_specs: list[TestSpec] = get_test_args(suite_spec.get_source_path(test))
test_specs: List[TestSpec] = get_test_args(suite_spec.get_source_path(test))
if mode is RunMode.TAP:
print('1..' + str(len(test_specs) * len(ceed_backends)))

# list of (test, test_specs, ceed_backend, ...) tuples generated from list of backends and test specs
args: list[TestCase] = [(i, test, spec, backend, mode, nproc, suite_spec)
args: List[TestCase] = [(i, test, spec, backend, mode, nproc, suite_spec)
for i, (spec, backend) in enumerate(product(test_specs, ceed_backends), start=1)]

with mp.Pool(processes=pool_size, initializer=init_process) as pool:
async_outputs: list[mp.AsyncResult] = [pool.apply_async(run_test, argv) for argv in args]
async_outputs: List[mp.AsyncResult] = [pool.apply_async(run_test, argv) for argv in args]

test_cases = []
for async_output in async_outputs:
Expand Down

0 comments on commit a77d276

Please sign in to comment.