Skip to content

Commit

Permalink
Allow cache_coord to receive coordinates in different system
Browse files Browse the repository at this point in the history
  • Loading branch information
cmacmackin committed Sep 23, 2024
1 parent eb006ad commit 1c8fe5a
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions neso_fame/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from collections import defaultdict
from collections.abc import (
Iterable,
Iterator,
Expand Down Expand Up @@ -627,6 +628,7 @@ def empty_coord(
@dataclass(frozen=True)
class _CoordRecord:
ctype: Type[Coord] | Type[SliceCoord]
system: CoordinateSystem
index: int


Expand All @@ -638,16 +640,26 @@ def coord_cache(
def decorator(
func: Callable[P, T],
) -> Callable[P, T]:
coord_data = MutableCoordMap.empty_coord(int, rtol, atol)
slicecoord_data = MutableCoordMap.empty_slicecoord(int, rtol, atol)
coord_data: dict[CoordinateSystem, MutableCoordMap[Coord, int]] = defaultdict(
lambda: MutableCoordMap.empty_coord(int, rtol, atol)
)
slicecoord_data: dict[CoordinateSystem, MutableCoordMap[SliceCoord, int]] = (
defaultdict(lambda: MutableCoordMap.empty_slicecoord(int, rtol, atol))
)
cache_data: dict[tuple, T] = {}

def process_arg(x: object) -> object:
if isinstance(x, Coord):
return _CoordRecord(Coord, coord_data.setdefault(x, len(coord_data)))
sys = x.system
dat = coord_data[sys]
return _CoordRecord(Coord, sys, dat.setdefault(x, len(dat)))
elif isinstance(x, SliceCoord):
sys = x.system
sdat = slicecoord_data[sys]
return _CoordRecord(
SliceCoord, slicecoord_data.setdefault(x, len(slicecoord_data))
SliceCoord,
sys,
sdat.setdefault(x, len(sdat)),
)
return x

Expand Down

0 comments on commit 1c8fe5a

Please sign in to comment.