Skip to content

Commit

Permalink
Merge pull request #312 from DedalusProject/eval_copy
Browse files Browse the repository at this point in the history
Refined task processing
  • Loading branch information
kburns authored Jan 22, 2025
2 parents 7925795 + 1475515 commit 9433d34
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 23 deletions.
3 changes: 3 additions & 0 deletions dedalus/core/distributor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from collections import OrderedDict
from math import prod
import numbers
from weakref import WeakSet

from .coords import CoordinateSystem, DirectProduct
from ..tools.array import reshape_vector
Expand Down Expand Up @@ -112,6 +113,8 @@ def __init__(self, coordsystems, comm=None, mesh=None, dtype=None):
self.comm_coords = np.array(self.comm_cart.coords, dtype=int)
# Build layout objects
self._build_layouts()
# Keep set of weak field references
self.fields = WeakSet()

@CachedAttribute
def cs_by_axis(self):
Expand Down
29 changes: 13 additions & 16 deletions dedalus/core/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from .future import FutureField, FutureLockedField
from .field import Field, LockedField
from .operators import Copy
from ..tools.cache import CachedAttribute
from ..tools.general import OrderedSet
from ..tools.general import oscillate
Expand Down Expand Up @@ -127,22 +128,19 @@ def evaluate_handlers(self, handlers, id=None, **kw):
# Attempt evaluation
tasks = self.attempt_tasks(tasks, id=id)

# # Transform all outputs to coefficient layout to dealias
## D3 note: need to worry about this for redundent tasks?
# outputs = OrderedSet([t['out'] for h in handlers for t in h.tasks])
# self.require_coeff_space(outputs)

# # Copy redundant outputs so processing is independent
# outputs = set()
# for handler in handlers:
# for task in handler.tasks:
# if task['out'] in outputs:
# task['out'] = task['out'].copy()
# else:
# outputs.add(task['out'])
# Transform all outputs to coefficient layout to dealias
outputs = OrderedSet([t['out'] for h in handlers for t in h.tasks if not isinstance(t['out'], LockedField)])
self.require_coeff_space(outputs)

# Copy redundant outputs so processing is independent
outputs = set()
for handler in handlers:
for task in handler.tasks:
if task['out'] in outputs:
task['out'] = task['out'].copy()
else:
outputs.add(task['out'])

# Process
for handler in handlers:
handler.process(**kw)
Expand Down Expand Up @@ -285,10 +283,9 @@ def add_task(self, task, layout='g', name=None, scales=None):
# Create operator
if isinstance(task, str):
op = FutureField.parse(task, self.vars, self.dist)
elif isinstance(task, Field):
op = Copy(task)
else:
# op = FutureField.cast(task, self.domain)
# op = Cast(task)
# TODO: figure out if we need to copying here
op = task
# Check scales
if isinstance(op, (LockedField, FutureLockedField)):
Expand Down
9 changes: 9 additions & 0 deletions dedalus/core/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,8 @@ def __init__(self, dist, bases=None, name=None, tensorsig=None, dtype=None):
self.layout = self.dist.get_layout_object('c')
# Change scales to build buffer and data
self.preset_scales((1,) * self.dist.dim)
# Add weak reference to distributor
dist.fields.add(self)

def __getitem__(self, layout):
"""Return data viewed in specified layout."""
Expand Down Expand Up @@ -1022,3 +1024,10 @@ def lock_to_layouts(self, *layouts):
def lock_axis_to_grid(self, axis):
self.allowed_layouts = tuple(l for l in self.dist.layouts if l.grid_space[axis])

def unlock(self):
"""Return regular Field object with same data and no layout locking."""
field = Field(self.dist, bases=self.domain.bases, name=self.name, tensorsig=self.tensorsig, dtype=self.dtype)
field.preset_scales(self.scales)
field[self.layout] = self.data
return field

4 changes: 3 additions & 1 deletion dedalus/core/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from .domain import Domain
from . import coords
from .field import Operand, Field
from .field import Operand, Field, LockedField
from .future import Future, FutureField, FutureLockedField
from ..tools.array import reshape_vector, apply_matrix, add_sparse, axindex, axslice, perm_matrix, copyto, sparse_block_diag, interleave_matrices
from ..tools.cache import CachedAttribute, CachedMethod
Expand Down Expand Up @@ -1492,6 +1492,8 @@ class Copy(LinearOperator):

def __init__(self, operand, out=None):
super().__init__(operand, out=out)
if isinstance(operand, (LockedField, FutureLockedField)):
raise ValueError("Not yet implemented for locked fields.")
# LinearOperator requirements
self.operand = operand
# FutureField requirements
Expand Down
6 changes: 5 additions & 1 deletion dedalus/core/problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
from collections import ChainMap

from .field import Operand, Field
from .field import Operand, Field, LockedField
from . import arithmetic
from . import operators
from . import solvers
Expand Down Expand Up @@ -244,6 +244,10 @@ def _build_matrix_expressions(self, eqn):
# Extract matrix expressions
F = eqn['LHS'] - eqn['RHS']
dF = F.frechet_differential(vars, perts)
# Remove any field locks
dF = dF.replace(operators.Lock, lambda x: x)
for field in dF.atoms(LockedField):
dF = dF.replace(field, field.unlock())
# Reinitialize and prep NCCs
dF = dF.reinitialize(ncc=True, ncc_vars=perts)
dF.prep_nccs(vars=perts)
Expand Down
10 changes: 5 additions & 5 deletions dedalus/core/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,11 +685,6 @@ def step(self, dt):
# Assert finite timestep
if not np.isfinite(dt):
raise ValueError("Invalid timestep")
# Enforce Hermitian symmetry for real variables
if self.enforce_real_cadence:
# Enforce for as many iterations as timestepper uses internally
if self.iteration % self.enforce_real_cadence < self.timestepper.steps:
self.enforce_hermitian_symmetry(self.state)
# Record times
wall_time = self.wall_time
if self.iteration == self.initial_iteration:
Expand All @@ -706,6 +701,11 @@ def step(self, dt):
self.run_time_start = self.wall_time
# Advance using timestepper
self.timestepper.step(dt, wall_time)
# Enforce Hermitian symmetry for real variables
if self.enforce_real_cadence:
# Enforce for as many iterations as timestepper uses internally
if self.iteration % self.enforce_real_cadence < self.timestepper.steps:
self.enforce_hermitian_symmetry(self.state)
# Update iteration
self.iteration += 1
self.dt = dt
Expand Down

0 comments on commit 9433d34

Please sign in to comment.