diff --git a/tests/conftest.py b/tests/conftest.py index ea5e650..4391203 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -291,8 +291,18 @@ def make_line(point1: Pair, point2: Pair) -> mesh.StraightLine: make_line(sorted_corners[3], sorted_corners[0]), ) + def _get_alignment(fixed1: bool, fixed2: bool) -> mesh.QuadAlignment: - return mesh.QuadAlignment.NONALIGNED if fixed1 and fixed2 else mesh.QuadAlignment.SOUTH if fixed1 else mesh.QuadAlignment.NORTH if fixed2 else mesh.QuadAlignment.ALIGNED + return ( + mesh.QuadAlignment.NONALIGNED + if fixed1 and fixed2 + else mesh.QuadAlignment.SOUTH + if fixed1 + else mesh.QuadAlignment.NORTH + if fixed2 + else mesh.QuadAlignment.ALIGNED + ) + def trapezohedronal_hex( a1: float, @@ -305,7 +315,7 @@ def trapezohedronal_hex( division: int, num_divisions: int, offset: float, - fixed_edges: tuple[bool, bool, bool, bool] + fixed_edges: tuple[bool, bool, bool, bool], ) -> Optional[mesh.Hex]: centre = ( sum(map(operator.itemgetter(0), starts)), @@ -432,7 +442,7 @@ def curved_hex( division: int, num_divisions: int, offset: float, - fixed_edges: tuple[bool, bool, bool, bool] + fixed_edges: tuple[bool, bool, bool, bool], ) -> mesh.Hex: sorted_starts = sorted(starts, key=operator.itemgetter(0)) sorted_starts = sorted(sorted_starts[0:2], key=operator.itemgetter(1)) + sorted( @@ -560,10 +570,13 @@ def make_shape(starts: tuple[Pair, Pair]) -> mesh.AcrossFieldCurve: points = np.linspace(limits[0], limits[1], num_quads + 1) fixed = [left_fixed] + [False] * (num_quads - 1) + [right_fixed] - + return [ mesh.Quad(shape, trace, a3, aligned_edges=align) - for shape, align in zip(map(make_shape, itertools.pairwise(points)), itertools.starmap(_get_alignment, itertools.pairwise(fixed))) + for shape, align in zip( + map(make_shape, itertools.pairwise(points)), + itertools.starmap(_get_alignment, itertools.pairwise(fixed)), + ) ] @@ -595,21 +608,43 @@ def make_line(start: Pair, end: Pair) -> mesh.AcrossFieldCurve: mesh.SliceCoord(end[0], end[1], c), ) - def get_alignment(is_bound: bool, north_bound: bool, south_bound: bool) -> mesh.QuadAlignment: + def get_alignment( + is_bound: bool, north_bound: bool, south_bound: bool + ) -> mesh.QuadAlignment: if not fixed_bounds: return mesh.QuadAlignment.ALIGNED if is_bound: return mesh.QuadAlignment.NONALIGNED return _get_alignment(north_bound, south_bound) - + def make_element_and_bounds( pairs: list[Pair], is_bound: list[bool] ) -> tuple[mesh.Hex, list[frozenset[mesh.Quad]]]: edges = [ - mesh.Quad(make_line(pairs[0], pairs[1]), trace, a3, aligned_edges=get_alignment(is_bound[0], is_bound[2], is_bound[3])), - mesh.Quad(make_line(pairs[2], pairs[3]), trace, a3, aligned_edges=get_alignment(is_bound[1], is_bound[2], is_bound[3])), - mesh.Quad(make_line(pairs[0], pairs[2]), trace, a3, aligned_edges=get_alignment(is_bound[2], is_bound[0], is_bound[1])), - mesh.Quad(make_line(pairs[1], pairs[3]), trace, a3, aligned_edges=get_alignment(is_bound[3], is_bound[0], is_bound[1])), + mesh.Quad( + make_line(pairs[0], pairs[1]), + trace, + a3, + aligned_edges=get_alignment(is_bound[0], is_bound[2], is_bound[3]), + ), + mesh.Quad( + make_line(pairs[2], pairs[3]), + trace, + a3, + aligned_edges=get_alignment(is_bound[1], is_bound[2], is_bound[3]), + ), + mesh.Quad( + make_line(pairs[0], pairs[2]), + trace, + a3, + aligned_edges=get_alignment(is_bound[2], is_bound[0], is_bound[1]), + ), + mesh.Quad( + make_line(pairs[1], pairs[3]), + trace, + a3, + aligned_edges=get_alignment(is_bound[3], is_bound[0], is_bound[1]), + ), ] return mesh.Hex(*edges), [ frozenset({e}) if b else frozenset() for e, b in zip(edges, is_bound) @@ -886,7 +921,7 @@ def hex_starts( _divisions, _num_divisions, whole_numbers, - fixed_edges + fixed_edges, ) flat_quad = one_of(linear_quad, nonlinear_quad) @@ -938,7 +973,7 @@ def hex_starts( coordinate_systems, integers(2, 10), booleans(), - booleans() + booleans(), ).filter(lambda x: x is not None), ) quad_mesh_arguments = quad_mesh_elements.map(lambda x: (x, get_quad_boundaries(x))) diff --git a/tests/test_mesh.py b/tests/test_mesh.py index 5cd358c..d77c346 100644 --- a/tests/test_mesh.py +++ b/tests/test_mesh.py @@ -28,7 +28,6 @@ cylindrical_field_line, cylindrical_field_trace, linear_field_trace, - linear_quad, mesh_arguments, mutually_broadcastable_arrays, non_nans,