Skip to content

Commit

Permalink
Formatting and linting
Browse files Browse the repository at this point in the history
  • Loading branch information
cmacmackin committed Nov 24, 2023
1 parent 464bdd1 commit 5059c9e
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 14 deletions.
61 changes: 48 additions & 13 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,18 @@ def make_line(point1: Pair, point2: Pair) -> mesh.StraightLine:
make_line(sorted_corners[3], sorted_corners[0]),
)


def _get_alignment(fixed1: bool, fixed2: bool) -> mesh.QuadAlignment:
return mesh.QuadAlignment.NONALIGNED if fixed1 and fixed2 else mesh.QuadAlignment.SOUTH if fixed1 else mesh.QuadAlignment.NORTH if fixed2 else mesh.QuadAlignment.ALIGNED
return (
mesh.QuadAlignment.NONALIGNED
if fixed1 and fixed2
else mesh.QuadAlignment.SOUTH
if fixed1
else mesh.QuadAlignment.NORTH
if fixed2
else mesh.QuadAlignment.ALIGNED
)


def trapezohedronal_hex(
a1: float,
Expand All @@ -305,7 +315,7 @@ def trapezohedronal_hex(
division: int,
num_divisions: int,
offset: float,
fixed_edges: tuple[bool, bool, bool, bool]
fixed_edges: tuple[bool, bool, bool, bool],
) -> Optional[mesh.Hex]:
centre = (
sum(map(operator.itemgetter(0), starts)),
Expand Down Expand Up @@ -432,7 +442,7 @@ def curved_hex(
division: int,
num_divisions: int,
offset: float,
fixed_edges: tuple[bool, bool, bool, bool]
fixed_edges: tuple[bool, bool, bool, bool],
) -> mesh.Hex:
sorted_starts = sorted(starts, key=operator.itemgetter(0))
sorted_starts = sorted(sorted_starts[0:2], key=operator.itemgetter(1)) + sorted(
Expand Down Expand Up @@ -560,10 +570,13 @@ def make_shape(starts: tuple[Pair, Pair]) -> mesh.AcrossFieldCurve:

points = np.linspace(limits[0], limits[1], num_quads + 1)
fixed = [left_fixed] + [False] * (num_quads - 1) + [right_fixed]

return [
mesh.Quad(shape, trace, a3, aligned_edges=align)
for shape, align in zip(map(make_shape, itertools.pairwise(points)), itertools.starmap(_get_alignment, itertools.pairwise(fixed)))
for shape, align in zip(
map(make_shape, itertools.pairwise(points)),
itertools.starmap(_get_alignment, itertools.pairwise(fixed)),
)
]


Expand Down Expand Up @@ -595,21 +608,43 @@ def make_line(start: Pair, end: Pair) -> mesh.AcrossFieldCurve:
mesh.SliceCoord(end[0], end[1], c),
)

def get_alignment(is_bound: bool, north_bound: bool, south_bound: bool) -> mesh.QuadAlignment:
def get_alignment(
is_bound: bool, north_bound: bool, south_bound: bool
) -> mesh.QuadAlignment:
if not fixed_bounds:
return mesh.QuadAlignment.ALIGNED
if is_bound:
return mesh.QuadAlignment.NONALIGNED
return _get_alignment(north_bound, south_bound)

def make_element_and_bounds(
pairs: list[Pair], is_bound: list[bool]
) -> tuple[mesh.Hex, list[frozenset[mesh.Quad]]]:
edges = [
mesh.Quad(make_line(pairs[0], pairs[1]), trace, a3, aligned_edges=get_alignment(is_bound[0], is_bound[2], is_bound[3])),
mesh.Quad(make_line(pairs[2], pairs[3]), trace, a3, aligned_edges=get_alignment(is_bound[1], is_bound[2], is_bound[3])),
mesh.Quad(make_line(pairs[0], pairs[2]), trace, a3, aligned_edges=get_alignment(is_bound[2], is_bound[0], is_bound[1])),
mesh.Quad(make_line(pairs[1], pairs[3]), trace, a3, aligned_edges=get_alignment(is_bound[3], is_bound[0], is_bound[1])),
mesh.Quad(
make_line(pairs[0], pairs[1]),
trace,
a3,
aligned_edges=get_alignment(is_bound[0], is_bound[2], is_bound[3]),
),
mesh.Quad(
make_line(pairs[2], pairs[3]),
trace,
a3,
aligned_edges=get_alignment(is_bound[1], is_bound[2], is_bound[3]),
),
mesh.Quad(
make_line(pairs[0], pairs[2]),
trace,
a3,
aligned_edges=get_alignment(is_bound[2], is_bound[0], is_bound[1]),
),
mesh.Quad(
make_line(pairs[1], pairs[3]),
trace,
a3,
aligned_edges=get_alignment(is_bound[3], is_bound[0], is_bound[1]),
),
]
return mesh.Hex(*edges), [
frozenset({e}) if b else frozenset() for e, b in zip(edges, is_bound)
Expand Down Expand Up @@ -886,7 +921,7 @@ def hex_starts(
_divisions,
_num_divisions,
whole_numbers,
fixed_edges
fixed_edges,
)

flat_quad = one_of(linear_quad, nonlinear_quad)
Expand Down Expand Up @@ -938,7 +973,7 @@ def hex_starts(
coordinate_systems,
integers(2, 10),
booleans(),
booleans()
booleans(),
).filter(lambda x: x is not None),
)
quad_mesh_arguments = quad_mesh_elements.map(lambda x: (x, get_quad_boundaries(x)))
Expand Down
1 change: 0 additions & 1 deletion tests/test_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
cylindrical_field_line,
cylindrical_field_trace,
linear_field_trace,
linear_quad,
mesh_arguments,
mutually_broadcastable_arrays,
non_nans,
Expand Down

0 comments on commit 5059c9e

Please sign in to comment.