Skip to content

Commit

Permalink
Linting
Browse files Browse the repository at this point in the history
  • Loading branch information
cmacmackin committed Nov 12, 2024
1 parent 4e469f4 commit efb3e87
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 79 deletions.
3 changes: 1 addition & 2 deletions neso_fame/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@
overload,
)

from _pytest.compat import assert_never

import numpy as np
import numpy.typing as npt
from _pytest.compat import assert_never
from typing_extensions import Self

from neso_fame.coordinates import (
Expand Down
7 changes: 6 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def slice_coord_for_system(
just(system),
)


def coord_for_system(
system: coordinates.CoordinateSystem,
) -> SearchStrategy[coordinates.SliceCoord]:
Expand All @@ -140,7 +141,11 @@ def coord_for_system(
register_type_strategy(
coordinates.SliceCoords,
builds(
lambda xs, c: coordinates.SliceCoords(np.abs(xs[0]) if c == coordinates.CoordinateSystem.CYLINDRICAL else xs[0], xs[1], c),
lambda xs, c: coordinates.SliceCoords(
np.abs(xs[0]) if c == coordinates.CoordinateSystem.CYLINDRICAL else xs[0],
xs[1],
c,
),
mutually_broadcastable_arrays(2),
sampled_from(coordinates.CoordinateSystem),
),
Expand Down
5 changes: 1 addition & 4 deletions tests/test_element_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from neso_fame.element_builder import ElementBuilder
from neso_fame.generators import _get_element_corners
from neso_fame.mesh import (
FieldTrace,
straight_line_across_field,
)

Expand Down Expand Up @@ -748,9 +747,7 @@ def _element_corners(
MOCK_MESH.equilibrium.o_point = Point2D(1.0, 0.0)

BUILDER = ElementBuilder(MOCK_MESH, simple_trace, 0.1, EMPTY_MAP)
BUILDER_UNFINISHED = ElementBuilder(
MOCK_MESH, simple_trace, 0.1, EMPTY_MAP
)
BUILDER_UNFINISHED = ElementBuilder(MOCK_MESH, simple_trace, 0.1, EMPTY_MAP)
# with (
# patch(
# "neso_fame.element_builder.flux_surface_edge",
Expand Down
10 changes: 2 additions & 8 deletions tests/test_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,15 +948,9 @@ def test_extruding_hypnotoad_mesh_fill_core() -> None:
o_point = SliceCoord(eq.o_point.R, eq.o_point.Z, CoordinateSystem.CYLINDRICAL)

def get_axis_edge(prism: Prism) -> FieldAlignedCurve:
curves = frozenset(q.north for q in prism) | frozenset(
q.south for q in prism
)
curves = frozenset(q.north for q in prism) | frozenset(q.south for q in prism)
assert len(curves) == 3
axis_curve = [
c
for c in curves
if c.start_points.to_coord() == o_point
]
axis_curve = [c for c in curves if c.start_points.to_coord() == o_point]
assert len(axis_curve) == 1
acurve = axis_curve[0]
return acurve
Expand Down
13 changes: 3 additions & 10 deletions tests/test_hypnotoad.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
floats,
integers,
just,
lists,
nothing,
one_of,
sampled_from,
Expand Down Expand Up @@ -975,9 +974,7 @@ def test_flux_surface_bounds(region: MeshRegion, dx3: float) -> None:
eq = region.meshParent.equilibrium

def constructor(north: SliceCoord, south: SliceCoord) -> Quad:
return Quad(
straight_line_across_field(north, south), simple_trace, dx3
)
return Quad(straight_line_across_field(north, south), simple_trace, dx3)

for points in get_region_flux_surface_boundary_points(region):
check_flux_surface_bound(
Expand All @@ -993,9 +990,7 @@ def test_perpendicular_bounds(region: MeshRegion, dx3: float) -> None:
eq = region.meshParent.equilibrium

def constructor(north: SliceCoord, south: SliceCoord) -> Quad:
return Quad(
straight_line_across_field(north, south), simple_trace, dx3
)
return Quad(straight_line_across_field(north, south), simple_trace, dx3)

for points in get_region_perpendicular_boundary_points(region):
check_perpendicular_bounds(
Expand Down Expand Up @@ -1092,9 +1087,7 @@ def test_region_bounds(
)
def test_mesh_bounds(mesh_args: Mesh, is_boundary: list[bool]) -> None:
def constructor(north: SliceCoord, south: SliceCoord) -> Quad:
return Quad(
straight_line_across_field(north, south), simple_trace, 1.0
)
return Quad(straight_line_across_field(north, south), simple_trace, 1.0)

mesh = to_mesh(mesh_args)
eq = mesh.equilibrium
Expand Down
131 changes: 85 additions & 46 deletions tests/test_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
sampled_from,
shared,
slices,
tuples
)

from neso_fame import coordinates, mesh
Expand All @@ -30,6 +29,8 @@
from .conftest import (
_hex_mesh_arguments,
_quad_mesh_elements,
common_coords,
common_slice_coords,
compatible_coords_and_alignments,
corners_to_poloidal_quad,
flat_sided_hex,
Expand All @@ -45,14 +46,12 @@
subdivideable_quad,
unbroadcastable_shape,
whole_numbers,
common_slice_coords,
common_coords,
)


divisions = shared(integers(-5, 10), key=5545)
pos_divisions = shared(integers(1, 10), key=5546)


@given(
compatible_coords_and_alignments,
linear_traces,
Expand Down Expand Up @@ -88,7 +87,7 @@ def test_subdivideable_field_aligned_positions(
linear_traces,
floats(0.1, 10.0),
integers(1, 10),
pos_divisions.flatmap(lambda x: integers(0, x-1)),
pos_divisions.flatmap(lambda x: integers(0, x - 1)),
pos_divisions,
)
def test_field_aligned_positions(
Expand Down Expand Up @@ -153,7 +152,7 @@ def test_field_aligned_positions_bad_divisions(
num_divisions: int,
) -> None:
starts, alignments = coords_alignemnts
with pytest.raises(ValueError, match=r'.*num_divisions.*'):
with pytest.raises(ValueError, match=r".*num_divisions.*"):
mesh.field_aligned_positions(
starts, dx3, field, alignments, order, 0, num_divisions
)
Expand Down Expand Up @@ -255,19 +254,19 @@ def test_field_aligned_positions_subdivide(
data = array
base += m


def test_field_aligned_positions_bad_subdivide() -> None:
starts = coordinates.SliceCoords(
np.array([0.0, 0.25, 0.5, 0.75, 1.0]),
np.array([-1.0, -0.5, 0.0, 0.5, 1.0]),
coordinates.CoordinateSystem.CARTESIAN,
)
alignments = np.array([1.0, 1.0, 1.0, 1.0, 0.0])
data = mesh.field_aligned_positions(
starts, 2.0, sample_trace, alignments, 4, 0, 1
)
data = mesh.field_aligned_positions(starts, 2.0, sample_trace, alignments, 4, 0, 1)
with pytest.raises(ValueError, match=r"Can not subdivide.*"):
next(data.subdivide(3))


@given(
compatible_coords_and_alignments,
linear_traces,
Expand Down Expand Up @@ -295,6 +294,7 @@ def test_field_aligned_positions_subdivide_self(

shared_coords_alignments = shared(compatible_coords_and_alignments, key=22222)


@given(
compatible_coords_and_alignments,
linear_traces,
Expand All @@ -316,7 +316,17 @@ def test_field_aligned_positions_poloidal_shape(
starts, dx3, field, alignments, order, subdivision, num_divisions
)
assert len(data.poloidal_shape) == len(starts.shape)
assert data.poloidal_shape == tuple(map(max, itertools.zip_longest(starts.shape[::-1], alignments.shape[::-1], fillvalue=0)))[::-1]
assert (
data.poloidal_shape
== tuple(
map(
max,
itertools.zip_longest(
starts.shape[::-1], alignments.shape[::-1], fillvalue=0
),
)
)[::-1]
)


@settings(report_multiple_bugs=False)
Expand Down Expand Up @@ -373,12 +383,16 @@ def test_field_aligned_positions_getitem(
)
result = data[idx]
assert result.start_points is not data.start_points
assert result.start_points.x1.base is data.start_points.x1 or not isinstance(
result.start_points.x1, np.ndarray
) or not isinstance(data.start_points.x1, np.ndarray)
assert result.start_points.x2.base is data.start_points.x2 or not isinstance(
result.start_points.x1, np.ndarray
) or not isinstance(data.start_points.x2, np.ndarray)
assert (
result.start_points.x1.base is data.start_points.x1
or not isinstance(result.start_points.x1, np.ndarray)
or not isinstance(data.start_points.x1, np.ndarray)
)
assert (
result.start_points.x2.base is data.start_points.x2
or not isinstance(result.start_points.x1, np.ndarray)
or not isinstance(data.start_points.x2, np.ndarray)
)
assert result.x3 is data.x3
assert result.trace is data.trace
assert result.alignments is not data.alignments
Expand Down Expand Up @@ -572,9 +586,7 @@ def test_field_aligned_positions_slice_caching() -> None:
alignments = np.array([1.0, 1.0, 0.0])
trace = MagicMock()
trace.side_effect = sample_trace
data = mesh.field_aligned_positions(
starts, 2.0, trace, alignments, 4
)
data = mesh.field_aligned_positions(starts, 2.0, trace, alignments, 4)
near = data[0]
expected_x1 = np.array(
[
Expand Down Expand Up @@ -620,21 +632,23 @@ def example_trace(


def test_straight_line_across_field() -> None:
line = mesh.control_points(mesh.straight_line_across_field(
mesh.SliceCoord(1., 1., coordinates.CoordinateSystem.CARTESIAN),
mesh.SliceCoord(2., 0., coordinates.CoordinateSystem.CARTESIAN),
4
))
line = mesh.control_points(
mesh.straight_line_across_field(
mesh.SliceCoord(1.0, 1.0, coordinates.CoordinateSystem.CARTESIAN),
mesh.SliceCoord(2.0, 0.0, coordinates.CoordinateSystem.CARTESIAN),
4,
)
)
np.testing.assert_allclose(line.x1, np.linspace(1, 2, 5), 1e-10, 1e-10)
np.testing.assert_allclose(line.x2, np.linspace(1, 0, 5), 1e-10, 1e-10)


def test_straight_line_across_field_mismatched_coords() -> None:
with pytest.raises(ValueError):
mesh.straight_line_across_field(
mesh.SliceCoord(1., 1., coordinates.CoordinateSystem.CARTESIAN),
mesh.SliceCoord(1., 1., coordinates.CoordinateSystem.CYLINDRICAL),
3
mesh.SliceCoord(1.0, 1.0, coordinates.CoordinateSystem.CARTESIAN),
mesh.SliceCoord(1.0, 1.0, coordinates.CoordinateSystem.CYLINDRICAL),
3,
)


Expand All @@ -646,6 +660,7 @@ def test_bad_straight_line() -> None:
4,
)


line_termini = sampled_from(coordinates.CoordinateSystem).flatmap(
lambda c: lists(
builds(mesh.Coord, whole_numbers, whole_numbers, whole_numbers, just(c)),
Expand Down Expand Up @@ -1170,35 +1185,59 @@ def test_mesh_len(m: mesh.GenericMesh) -> None:

shared_orders = shared(integers(1, 10), key=123456)
coords = builds(
coordinates.Coord, whole_numbers, whole_numbers, whole_numbers, just(coordinates.CoordinateSystem.CARTESIAN))
coordinates.Coord,
whole_numbers,
whole_numbers,
whole_numbers,
just(coordinates.CoordinateSystem.CARTESIAN),
)
lines_across_field = builds(
mesh.straight_line_across_field,
common_slice_coords,
common_slice_coords,
shared_orders
)
mesh.straight_line_across_field,
common_slice_coords,
common_slice_coords,
shared_orders,
)
geometries = one_of(
(
lines_across_field,
builds(mesh.straight_line, common_coords, common_coords, shared_orders),
builds(mesh.field_aligned_positions, lines_across_field, floats(1e-3, 1e3), linear_traces, floats(0., 1.0).map(np.array), shared_orders, pos_divisions.flatmap(lambda x: integers(0, x-1)), pos_divisions).map(mesh.Quad),
builds(
mesh.field_aligned_positions,
lines_across_field,
floats(1e-3, 1e3),
linear_traces,
floats(0.0, 1.0).map(np.array),
shared_orders,
pos_divisions.flatmap(lambda x: integers(0, x - 1)),
pos_divisions,
).map(mesh.Quad),
builds(
mesh.Prism,
sampled_from(tuple(mesh.PrismTypes)[::-1]),
builds(mesh.field_aligned_positions,
builds(corners_to_poloidal_quad, shared_orders, lists(common_slice_coords.map(tuple), min_size=4, max_size=4).map(tuple), shared_coordinate_systems),
floats(1e-3, 1e3),
linear_traces,
floats(0., 1.).map(np.array),
shared_orders,
pos_divisions.flatmap(lambda x: integers(0, x-1)),
pos_divisions,
)
)
builds(
mesh.field_aligned_positions,
builds(
corners_to_poloidal_quad,
shared_orders,
lists(common_slice_coords.map(tuple), min_size=4, max_size=4).map(
tuple
),
shared_coordinate_systems,
),
floats(1e-3, 1e3),
linear_traces,
floats(0.0, 1.0).map(np.array),
shared_orders,
pos_divisions.flatmap(lambda x: integers(0, x - 1)),
pos_divisions,
),
),
)
)


@given(shared_orders, geometries)
def test_order(n: int, geom: mesh.AcrossFieldCurve | mesh.Curve | mesh.Quad | mesh.Prism):
def test_order(
n: int, geom: mesh.AcrossFieldCurve | mesh.Curve | mesh.Quad | mesh.Prism
) -> None:
assert n == mesh.order(geom)

Loading

0 comments on commit efb3e87

Please sign in to comment.