Skip to content

Commit

Permalink
FEniCS-style bcs
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 28, 2025
1 parent 58f3d6f commit 72bfb6e
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 44 deletions.
73 changes: 47 additions & 26 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ def assemble(expr, *args, **kwargs):
`matrix.Matrix`.
is_base_form_preprocessed : bool
If `True`, skip preprocessing of the form.
current_state : firedrake.function.Function or None
If provided and ``zero_bc_nodes == True``, the boundary condition
nodes of the output are set to the residual of the boundary conditions.
Returns
-------
Expand Down Expand Up @@ -130,16 +133,21 @@ def assemble(expr, *args, **kwargs):
"""
if args:
raise RuntimeError(f"Got unexpected args: {args}")
tensor = kwargs.pop("tensor", None)
return get_assembler(expr, *args, **kwargs).assemble(tensor=tensor)

assemble_kwargs = {}
for key in ("tensor", "current_state"):
if key in kwargs:
assemble_kwargs[key] = kwargs.pop(key, None)
return get_assembler(expr, *args, **kwargs).assemble(**assemble_kwargs)


def get_assembler(form, *args, **kwargs):
"""Create an assembler.
Notes
-----
See `assemble` for descriptions of the parameters. ``tensor`` should not be passed to this function.
See `assemble` for descriptions of the parameters. ``tensor`` and
``current_state`` should not be passed to this function.
"""
is_base_form_preprocessed = kwargs.pop('is_base_form_preprocessed', False)
Expand Down Expand Up @@ -187,13 +195,15 @@ class ExprAssembler(object):
def __init__(self, expr):
self._expr = expr

def assemble(self, tensor=None):
def assemble(self, tensor=None, current_state=None):
"""Assemble the pointwise expression.
Parameters
----------
tensor : firedrake.function.Function or firedrake.cofunction.Cofunction or matrix.MatrixBase
Output tensor.
current_state : None
Ignored by this class.
Returns
-------
Expand All @@ -205,6 +215,7 @@ def assemble(self, tensor=None):
from ufl.checks import is_scalar_constant_expression

assert tensor is None
assert current_state is None
expr = self._expr
# Get BaseFormOperators (e.g. `Interpolate` or `ExternalOperator`)
base_form_operators = extract_base_form_operators(expr)
Expand Down Expand Up @@ -274,13 +285,15 @@ def allocate(self):
"""Allocate memory for the output tensor."""

@abc.abstractmethod
def assemble(self, tensor=None):
def assemble(self, tensor=None, current_state=None):
"""Assemble the form.
Parameters
----------
tensor : firedrake.cofunction.Cofunction or firedrake.function.Function or matrix.MatrixBase
Output tensor to contain the result of assembly; if `None`, a tensor of appropriate type is created.
current_state : firedrake.function.Function or None
If provided, the boundary condition nodes are set to the boundary condition residual.
Returns
-------
Expand Down Expand Up @@ -358,13 +371,15 @@ def allocation_integral_types(self):
else:
return self._allocation_integral_types

def assemble(self, tensor=None):
def assemble(self, tensor=None, current_state=None):
"""Assemble the form.
Parameters
----------
tensor : firedrake.cofunction.Cofunction or firedrake.function.Function or matrix.MatrixBase
Output tensor to contain the result of assembly.
current_state : firedrake.function.Function or None
If provided, the boundary condition nodes are set to the boundary condition residual.
Returns
-------
Expand All @@ -389,7 +404,7 @@ def visitor(e, *operands):
rank = len(self._form.arguments())
if rank == 1 and not isinstance(result, ufl.ZeroBaseForm):
for bc in self._bcs:
bc.zero(result)
OneFormAssembler._apply_bc(self, result, bc, u=current_state)

if tensor:
BaseFormAssembler.update_tensor(result, tensor)
Expand Down Expand Up @@ -968,13 +983,15 @@ def __init__(self, form, bcs=None, form_compiler_parameters=None, needs_zeroing=
super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters)
self._needs_zeroing = needs_zeroing

def assemble(self, tensor=None):
def assemble(self, tensor=None, current_state=None):
"""Assemble the form.
Parameters
----------
tensor : firedrake.cofunction.Cofunction or matrix.MatrixBase
Output tensor to contain the result of assembly; if `None`, a tensor of appropriate type is created.
current_state : firedrake.function.Function or None
If provided, the boundary condition nodes are set to the boundary condition residual.
Returns
-------
Expand All @@ -998,12 +1015,12 @@ def assemble(self, tensor=None):
self.execute_parloops(tensor)

for bc in self._bcs:
self._apply_bc(tensor, bc)
self._apply_bc(tensor, bc, u=current_state)

return self.result(tensor)

@abc.abstractmethod
def _apply_bc(self, tensor, bc):
def _apply_bc(self, tensor, bc, u=None):
"""Apply boundary condition."""

@abc.abstractmethod
Expand Down Expand Up @@ -1138,7 +1155,7 @@ def allocate(self):
comm=self._form.ufl_domains()[0]._comm
)

def _apply_bc(self, tensor, bc):
def _apply_bc(self, tensor, bc, u=None):
pass

def _check_tensor(self, tensor):
Expand Down Expand Up @@ -1199,26 +1216,29 @@ def allocate(self):
else:
raise RuntimeError(f"Not expected: found rank = {rank} and diagonal = {self._diagonal}")

def _apply_bc(self, tensor, bc):
def _apply_bc(self, tensor, bc, u=None):
# TODO Maybe this could be a singledispatchmethod?
if isinstance(bc, DirichletBC):
self._apply_dirichlet_bc(tensor, bc)
if self._diagonal:
bc.set(tensor, self._weight)
elif self._zero_bc_nodes:
bc.zero(tensor)
else:
# The residual belongs to a mixed space that is dual on the boundary nodes
# and primal on the interior nodes. Therefore, this is type-safe operation.
r = tensor.riesz_representation("l2")
bc.apply(r, u=u)
elif isinstance(bc, EquationBCSplit):
bc.zero(tensor)
type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False,
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal, weight=self._weight).assemble(tensor=tensor)
OneFormAssembler(bc.f, bcs=bc.bcs,
form_compiler_parameters=self._form_compiler_params,
needs_zeroing=False,
zero_bc_nodes=self._zero_bc_nodes,
diagonal=self._diagonal,
weight=self._weight).assemble(tensor=tensor, current_state=u)
else:
raise AssertionError

def _apply_dirichlet_bc(self, tensor, bc):
if self._diagonal:
bc.set(tensor, self._weight)
elif not self._zero_bc_nodes:
# NOTE this only works if tensor is a Function and not a Cofunction
bc.apply(tensor)
else:
bc.zero(tensor)

def _check_tensor(self, tensor):
if tensor.function_space() != self._form.arguments()[0].function_space().dual():
raise ValueError("Form's argument does not match provided result tensor")
Expand Down Expand Up @@ -1430,7 +1450,8 @@ def _all_assemblers(self):
all_assemblers.extend(_assembler._all_assemblers)
return tuple(all_assemblers)

def _apply_bc(self, tensor, bc):
def _apply_bc(self, tensor, bc, u=None):
assert u is None
op2tensor = tensor.M
spaces = tuple(a.function_space() for a in tensor.a.arguments())
V = bc.function_space()
Expand Down Expand Up @@ -1534,7 +1555,7 @@ def allocate(self):
options_prefix=self._options_prefix,
appctx=self._appctx or {})

def assemble(self, tensor=None):
def assemble(self, tensor=None, current_state=None):
if tensor is None:
tensor = self.allocate()
else:
Expand Down
12 changes: 8 additions & 4 deletions firedrake/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ def solve(*args, **kwargs):
To exclude Dirichlet boundary condition nodes through the use of a
:class`.RestrictedFunctionSpace`, set the ``restrict`` keyword
argument to be True.
To linearise around the initial guess before imposing boundary
conditions, set the ``pre_apply_bcs`` keyword argument to be False.
"""

assert len(args) > 0
Expand All @@ -151,7 +154,7 @@ def _solve_varproblem(*args, **kwargs):
eq, u, bcs, J, Jp, M, form_compiler_parameters, \
solver_parameters, nullspace, nullspace_T, \
near_nullspace, \
options_prefix, restrict = _extract_args(*args, **kwargs)
options_prefix, restrict, pre_apply_bcs = _extract_args(*args, **kwargs)

# Check whether solution is valid
if not isinstance(u, (function.Function, vector.Vector)):
Expand Down Expand Up @@ -184,7 +187,7 @@ def _solve_varproblem(*args, **kwargs):
# Create problem
problem = vs.NonlinearVariationalProblem(eq.lhs, u, bcs, J, Jp,
form_compiler_parameters=form_compiler_parameters,
restrict=restrict)
restrict=restrict, pre_apply_bcs=pre_apply_bcs)
# Create solver and call solve
solver = vs.NonlinearVariationalSolver(problem, solver_parameters=solver_parameters,
nullspace=nullspace,
Expand Down Expand Up @@ -297,7 +300,7 @@ def _extract_args(*args, **kwargs):
valid_kwargs = ["bcs", "J", "Jp", "M",
"form_compiler_parameters", "solver_parameters",
"nullspace", "transpose_nullspace", "near_nullspace",
"options_prefix", "appctx", "restrict"]
"options_prefix", "appctx", "restrict", "pre_apply_bcs"]
for kwarg in kwargs.keys():
if kwarg not in valid_kwargs:
raise RuntimeError("Illegal keyword argument '%s'; valid keywords \
Expand Down Expand Up @@ -341,10 +344,11 @@ def _extract_args(*args, **kwargs):
solver_parameters = kwargs.get("solver_parameters", {})
options_prefix = kwargs.get("options_prefix", None)
restrict = kwargs.get("restrict", False)
pre_apply_bcs = kwargs.get("pre_apply_bcs", True)

return eq, u, bcs, J, Jp, M, form_compiler_parameters, \
solver_parameters, nullspace, nullspace_T, near_nullspace, \
options_prefix, restrict
options_prefix, restrict, pre_apply_bcs


def _extract_bcs(bcs):
Expand Down
24 changes: 19 additions & 5 deletions firedrake/solving_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import numpy

from pyop2 import op2
from firedrake import function, cofunction, dmhooks
from firedrake import dmhooks
from firedrake.function import Function
from firedrake.cofunction import Cofunction
from firedrake.exceptions import ConvergenceError
from firedrake.petsc import PETSc, DEFAULT_KSP_PARAMETERS
from firedrake.formmanipulation import ExtractSubBlock
Expand Down Expand Up @@ -219,8 +221,15 @@ def __init__(self, problem, mat_type, pmat_type, appctx=None,
self.bcs_J = tuple(bc.extract_form('J') for bc in problem.bcs)
self.bcs_Jp = tuple(bc.extract_form('Jp') for bc in problem.bcs)

self._bc_residual = None
if not problem.pre_apply_bcs:
# Delayed lifting of DirichletBCs
self._bc_residual = Function(self._x.function_space())
self.F -= self.J * self._bc_residual

self._assemble_residual = get_assembler(self.F, bcs=self.bcs_F,
form_compiler_parameters=self.fcp,
zero_bc_nodes=problem.pre_apply_bcs,
).assemble

self._jacobian_assembled = False
Expand Down Expand Up @@ -325,11 +334,11 @@ def split(self, fields):
pieces = [us[i].dat for i in field]
if len(pieces) == 1:
val, = pieces
subu = function.Function(V, val=val)
subu = Function(V, val=val)
subsplit = (subu, )
else:
val = op2.MixedDat(pieces)
subu = function.Function(V, val=val)
subu = Function(V, val=val)
# Split it apart to shove in the form.
subsplit = split(subu)
vec = []
Expand Down Expand Up @@ -395,7 +404,12 @@ def form_function(snes, X, F):
if ctx._pre_function_callback is not None:
ctx._pre_function_callback(X)

ctx._assemble_residual(tensor=ctx._F)
if ctx._bc_residual is not None:
# Delayed lifting of DirichletBC
ctx._bc_residual.zero()
for bc in ctx.bcs_F:
bc.apply(ctx._bc_residual, u=ctx._x)
ctx._assemble_residual(tensor=ctx._F, current_state=ctx._x)

if ctx._post_function_callback is not None:
with ctx._F.dat.vec as F_:
Expand Down Expand Up @@ -518,4 +532,4 @@ def _assemble_pjac(self):

@cached_property
def _F(self):
return cofunction.Cofunction(self.F.arguments()[0].function_space().dual())
return Cofunction(self.F.arguments()[0].function_space().dual())
17 changes: 12 additions & 5 deletions firedrake/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class NonlinearVariationalProblem(NonlinearVariationalProblemMixin):
def __init__(self, F, u, bcs=None, J=None,
Jp=None,
form_compiler_parameters=None,
is_linear=False, restrict=False):
is_linear=False, restrict=False, pre_apply_bcs=True):
r"""
:param F: the nonlinear form
:param u: the :class:`.Function` to solve for
Expand All @@ -68,6 +68,8 @@ def __init__(self, F, u, bcs=None, J=None,
:param restrict: (optional) If `True`, use restricted function spaces,
that exclude Dirichlet boundary condition nodes, internally for
the test and trial spaces.
:param pre_apply_bcs: (optional) If `False`, the problem is linearised
around the initial guess before imposing the boundary conditions.
"""
V = u.function_space()
self.output_space = V
Expand All @@ -86,6 +88,7 @@ def __init__(self, F, u, bcs=None, J=None,
if isinstance(bc, EquationBC):
restrict = False
self.restrict = restrict
self.pre_apply_bcs = pre_apply_bcs

if restrict and bcs:
V_res = restricted_function_space(V, extract_subdomain_ids(bcs))
Expand Down Expand Up @@ -304,8 +307,9 @@ def solve(self, bounds=None):
problem_dms = [V.dm for V in utils.unique(chain.from_iterable(c.function_space() for c in coefficients)) if V.dm != solution_dm]
problem_dms.append(solution_dm)

for dbc in problem.dirichlet_bcs():
dbc.apply(problem.u_restrict)
if problem.pre_apply_bcs:
for dbc in problem.dirichlet_bcs():
dbc.apply(problem.u_restrict)

if bounds is not None:
lower, upper = bounds
Expand All @@ -327,6 +331,7 @@ def solve(self, bounds=None):
self._setup = True
if problem.restrict:
problem.u.interpolate(problem.u_restrict)

solving_utils.check_snes_convergence(self.snes)

# Grab the comm associated with the `_problem` and call PETSc's garbage cleanup routine
Expand All @@ -340,7 +345,7 @@ class LinearVariationalProblem(NonlinearVariationalProblem):
@PETSc.Log.EventDecorator()
def __init__(self, a, L, u, bcs=None, aP=None,
form_compiler_parameters=None,
constant_jacobian=False, restrict=False):
constant_jacobian=False, restrict=False, pre_apply_bcs=True):
r"""
:param a: the bilinear form
:param L: the linear form
Expand All @@ -358,6 +363,8 @@ def __init__(self, a, L, u, bcs=None, aP=None,
:param restrict: (optional) If `True`, use restricted function spaces,
that exclude Dirichlet boundary condition nodes, internally for
the test and trial spaces.
:param pre_apply_bcs: (optional) If `False`, the problem is linearised
around the initial guess before imposing the boundary conditions.
"""
# In the linear case, the Jacobian is the equation LHS.
J = a
Expand All @@ -373,7 +380,7 @@ def __init__(self, a, L, u, bcs=None, aP=None,

super(LinearVariationalProblem, self).__init__(F, u, bcs, J, aP,
form_compiler_parameters=form_compiler_parameters,
is_linear=True, restrict=restrict)
is_linear=True, restrict=restrict, pre_apply_bcs=pre_apply_bcs)
self._constant_jacobian = constant_jacobian


Expand Down
Loading

0 comments on commit 72bfb6e

Please sign in to comment.