Skip to content

Commit

Permalink
Switch to using an R-tree to determine equality of points
Browse files Browse the repository at this point in the history
  • Loading branch information
cmacmackin committed Sep 13, 2024
1 parent 6d5156f commit a5a8653
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 8 deletions.
57 changes: 55 additions & 2 deletions neso_fame/nektar_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@
from dataclasses import dataclass
from functools import cache, reduce
from operator import attrgetter, or_
from typing import Optional, cast
from typing import Callable, Optional, TypeVar, cast

import NekPy.LibUtilities as LU
import NekPy.SpatialDomains as SD
import numpy as np
import numpy.typing as npt
from rtree.index import Index, Property

from .mesh import (
Coord,
CoordinateSystem,
Coords,
EndShape,
Mesh,
Expand Down Expand Up @@ -234,7 +236,55 @@ def reset_id_counts() -> None:
_solid_count = -1


@cache
T = TypeVar("T")
atol = 1e-8
rtol = 1e-8


def _get_bound_box(position: Coord, atol: float, rtol: float) -> tuple[float, ...]:
if position.system != CoordinateSystem.CARTESIAN:
raise ValueError(
"Should only construct bounding-boxes around cartesian coordinates"
)

def offset(x: float) -> tuple[float, float]:
dx = max(abs(atol), abs(x * rtol))
return x - dx, x + dx

return tuple(itertools.chain.from_iterable(map(offset, position)))


def rtree_cache(func: Callable[[Coord, int, int], T]) -> Callable[[Coord, int, int], T]:
"""Return a wrapped function that caches based on proximity of coordinates."""
# Create R-tree object
cache_data: dict[tuple[int, int], tuple[Index, list[T]]] = {}

def wrapper(position: Coord, spatial_dim: int, layer_id: int) -> T:
pos = position.to_cartesian()
idx = (spatial_dim, layer_id)
if idx not in cache_data:
obj = func(pos, spatial_dim, layer_id)
rtree = Index(interleaved=False, properties=Property(dimension=3))
rtree.insert(0, _get_bound_box(pos, atol, rtol))
cache_data[idx] = (
rtree,
[obj],
)
return obj
tree, objects = cache_data[idx]
bounds = _get_bound_box(pos, atol, rtol)
for item in tree.intersection(bounds, objects=False):
# Return the point for the first intersection found
return objects[item]
obj = func(pos, spatial_dim, layer_id)
tree.insert(len(objects), bounds)
objects.append(obj)
return obj

return wrapper


@rtree_cache
def nektar_point(position: Coord, spatial_dim: int, layer_id: int) -> SD.PointGeom:
"""Return a Nektar++ PointGeom object at the specified position.
Expand Down Expand Up @@ -594,6 +644,9 @@ def _nektar_prism(
return frozenset({nek_solid}), frozenset(faces), segments, points


# FIXME: Should this actually be cached? Caching on a Prism won't be
# reliable. Don't think we'll ever be processing a prism more than
# once anyway.
@cache
def nektar_3d_element(
solid: Prism, order: int, spatial_dim: int, layer_id: int
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dependencies = [
"click",
"meshio",
"meshpy",
"rtree",
]
dynamic = ["version"]

Expand Down
13 changes: 8 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def arbitrary_arrays() -> SearchStrategy[npt.NDArray]:
return arrays(floating_dtypes(), array_shapes())


WHOLE_NUM_MAX = 900
WHOLE_NUM_MAX = 1000
whole_numbers = integers(-WHOLE_NUM_MAX, WHOLE_NUM_MAX).map(float)
nonnegative_numbers = integers(1, WHOLE_NUM_MAX).map(float)
non_zero = whole_numbers.filter(bool)
Expand Down Expand Up @@ -97,6 +97,9 @@ def shape_to_array(


coordinate_systems = sampled_from(mesh.CoordinateSystem)
coordinate_systems3d = coordinate_systems.filter(
lambda x: x != mesh.CoordinateSystem.CARTESIAN2D
)


def slice_coord_for_system(
Expand Down Expand Up @@ -1150,7 +1153,7 @@ def hex_starts(
whole_numbers,
non_zero,
whole_numbers.flatmap(hex_starts),
coordinate_systems,
coordinate_systems3d,
floats(-2.0, 2.0),
integers(2, 5),
_divisions,
Expand All @@ -1169,7 +1172,7 @@ def hex_starts(
small_whole_numbers.flatmap(
lambda x: hex_starts(x, absmax=SMALL_WHOLE_NUM_MAX)
).map(lambda x: x[:3]),
coordinate_systems,
coordinate_systems3d,
floats(-2.0, 2.0),
integers(2, 5),
_divisions,
Expand Down Expand Up @@ -1328,7 +1331,7 @@ def hex_starts(
whole_numbers.flatmap(hex_starts),
integers(1, 3),
integers(1, 3),
coordinate_systems,
coordinate_systems3d,
integers(2, 5),
booleans(),
).filter(lambda x: x is not None),
Expand All @@ -1345,7 +1348,7 @@ def hex_starts(
),
integers(1, 3),
integers(1, 3),
coordinate_systems,
coordinate_systems3d,
integers(2, 5),
booleans(),
)
Expand Down
1 change: 0 additions & 1 deletion tests/test_nektar_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,6 @@ def test_poloidal_curve(q: Quad, s: float) -> None:
np.testing.assert_allclose(x2, expected.x2, 1e-8, 1e-8)
np.testing.assert_allclose(x3, 0.0, 1e-8, 1e-8)


@given(from_type(Prism), integers(1, 12), integers())
def test_nektar_poloidal_face(solid: Prism, order: int, layer: int) -> None:
shapes, segments, points = nektar_writer.nektar_poloidal_face(
Expand Down

0 comments on commit a5a8653

Please sign in to comment.