Skip to content

Commit

Permalink
FEniCS-style bcs
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 28, 2025
1 parent 58f3d6f commit ec733f2
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 31 deletions.
48 changes: 27 additions & 21 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def allocation_integral_types(self):
else:
return self._allocation_integral_types

def assemble(self, tensor=None):
def assemble(self, tensor=None, current_state=None):
"""Assemble the form.
Parameters
Expand Down Expand Up @@ -389,7 +389,7 @@ def visitor(e, *operands):
rank = len(self._form.arguments())
if rank == 1 and not isinstance(result, ufl.ZeroBaseForm):
for bc in self._bcs:
bc.zero(result)
OneFormAssembler._apply_bc(self, result, bc, u=current_state)

if tensor:
BaseFormAssembler.update_tensor(result, tensor)
Expand Down Expand Up @@ -968,7 +968,7 @@ def __init__(self, form, bcs=None, form_compiler_parameters=None, needs_zeroing=
super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters)
self._needs_zeroing = needs_zeroing

def assemble(self, tensor=None):
def assemble(self, tensor=None, current_state=None):
"""Assemble the form.
Parameters
Expand Down Expand Up @@ -998,12 +998,12 @@ def assemble(self, tensor=None):
self.execute_parloops(tensor)

for bc in self._bcs:
self._apply_bc(tensor, bc)
self._apply_bc(tensor, bc, u=current_state)

return self.result(tensor)

@abc.abstractmethod
def _apply_bc(self, tensor, bc):
def _apply_bc(self, tensor, bc, u=None):
"""Apply boundary condition."""

@abc.abstractmethod
Expand Down Expand Up @@ -1138,7 +1138,7 @@ def allocate(self):
comm=self._form.ufl_domains()[0]._comm
)

def _apply_bc(self, tensor, bc):
def _apply_bc(self, tensor, bc, u=None):
pass

def _check_tensor(self, tensor):
Expand Down Expand Up @@ -1199,26 +1199,31 @@ def allocate(self):
else:
raise RuntimeError(f"Not expected: found rank = {rank} and diagonal = {self._diagonal}")

def _apply_bc(self, tensor, bc):
def _apply_bc(self, tensor, bc, u=None):
# TODO Maybe this could be a singledispatchmethod?
if isinstance(bc, DirichletBC):
self._apply_dirichlet_bc(tensor, bc)
if self._diagonal:
bc.set(tensor, self._weight)
elif self._zero_bc_nodes:
bc.zero(tensor)
elif u is not None:
assert self._weight == 1.0
tensor = tensor.riesz_representation("l2")
bc.apply(tensor, u=u)
else:
# NOTE this only works if tensor is a Function and not a Cofunction
bc.apply(tensor)
elif isinstance(bc, EquationBCSplit):
bc.zero(tensor)
type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False,
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal, weight=self._weight).assemble(tensor=tensor)
OneFormAssembler(bc.f, bcs=bc.bcs,
form_compiler_parameters=self._form_compiler_params,
needs_zeroing=False,
zero_bc_nodes=self._zero_bc_nodes,
diagonal=self._diagonal,
weight=self._weight).assemble(tensor=tensor, current_state=u)
else:
raise AssertionError

def _apply_dirichlet_bc(self, tensor, bc):
if self._diagonal:
bc.set(tensor, self._weight)
elif not self._zero_bc_nodes:
# NOTE this only works if tensor is a Function and not a Cofunction
bc.apply(tensor)
else:
bc.zero(tensor)

def _check_tensor(self, tensor):
if tensor.function_space() != self._form.arguments()[0].function_space().dual():
raise ValueError("Form's argument does not match provided result tensor")
Expand Down Expand Up @@ -1430,7 +1435,8 @@ def _all_assemblers(self):
all_assemblers.extend(_assembler._all_assemblers)
return tuple(all_assemblers)

def _apply_bc(self, tensor, bc):
def _apply_bc(self, tensor, bc, u=None):
assert u is None
op2tensor = tensor.M
spaces = tuple(a.function_space() for a in tensor.a.arguments())
V = bc.function_space()
Expand Down Expand Up @@ -1534,7 +1540,7 @@ def allocate(self):
options_prefix=self._options_prefix,
appctx=self._appctx or {})

def assemble(self, tensor=None):
def assemble(self, tensor=None, current_state=None):
if tensor is None:
tensor = self.allocate()
else:
Expand Down
12 changes: 8 additions & 4 deletions firedrake/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ def solve(*args, **kwargs):
To exclude Dirichlet boundary condition nodes through the use of a
:class`.RestrictedFunctionSpace`, set the ``restrict`` keyword
argument to be True.
To linearise around the initial guess before imposing boundary
conditions, set the ``pre_apply_bcs`` keyword argument to be False.
"""

assert len(args) > 0
Expand All @@ -151,7 +154,7 @@ def _solve_varproblem(*args, **kwargs):
eq, u, bcs, J, Jp, M, form_compiler_parameters, \
solver_parameters, nullspace, nullspace_T, \
near_nullspace, \
options_prefix, restrict = _extract_args(*args, **kwargs)
options_prefix, restrict, pre_apply_bcs = _extract_args(*args, **kwargs)

# Check whether solution is valid
if not isinstance(u, (function.Function, vector.Vector)):
Expand Down Expand Up @@ -184,7 +187,7 @@ def _solve_varproblem(*args, **kwargs):
# Create problem
problem = vs.NonlinearVariationalProblem(eq.lhs, u, bcs, J, Jp,
form_compiler_parameters=form_compiler_parameters,
restrict=restrict)
restrict=restrict, pre_apply_bcs=pre_apply_bcs)
# Create solver and call solve
solver = vs.NonlinearVariationalSolver(problem, solver_parameters=solver_parameters,
nullspace=nullspace,
Expand Down Expand Up @@ -297,7 +300,7 @@ def _extract_args(*args, **kwargs):
valid_kwargs = ["bcs", "J", "Jp", "M",
"form_compiler_parameters", "solver_parameters",
"nullspace", "transpose_nullspace", "near_nullspace",
"options_prefix", "appctx", "restrict"]
"options_prefix", "appctx", "restrict", "pre_apply_bcs"]
for kwarg in kwargs.keys():
if kwarg not in valid_kwargs:
raise RuntimeError("Illegal keyword argument '%s'; valid keywords \
Expand Down Expand Up @@ -341,10 +344,11 @@ def _extract_args(*args, **kwargs):
solver_parameters = kwargs.get("solver_parameters", {})
options_prefix = kwargs.get("options_prefix", None)
restrict = kwargs.get("restrict", False)
pre_apply_bcs = kwargs.get("pre_apply_bcs", True)

return eq, u, bcs, J, Jp, M, form_compiler_parameters, \
solver_parameters, nullspace, nullspace_T, near_nullspace, \
options_prefix, restrict
options_prefix, restrict, pre_apply_bcs


def _extract_bcs(bcs):
Expand Down
3 changes: 2 additions & 1 deletion firedrake/solving_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def __init__(self, problem, mat_type, pmat_type, appctx=None,

self._assemble_residual = get_assembler(self.F, bcs=self.bcs_F,
form_compiler_parameters=self.fcp,
zero_bc_nodes=problem.pre_apply_bcs,
).assemble

self._jacobian_assembled = False
Expand Down Expand Up @@ -395,7 +396,7 @@ def form_function(snes, X, F):
if ctx._pre_function_callback is not None:
ctx._pre_function_callback(X)

ctx._assemble_residual(tensor=ctx._F)
ctx._assemble_residual(tensor=ctx._F, current_state=ctx._x)

if ctx._post_function_callback is not None:
with ctx._F.dat.vec as F_:
Expand Down
16 changes: 11 additions & 5 deletions firedrake/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class NonlinearVariationalProblem(NonlinearVariationalProblemMixin):
def __init__(self, F, u, bcs=None, J=None,
Jp=None,
form_compiler_parameters=None,
is_linear=False, restrict=False):
is_linear=False, restrict=False, pre_apply_bcs=True):
r"""
:param F: the nonlinear form
:param u: the :class:`.Function` to solve for
Expand All @@ -68,6 +68,8 @@ def __init__(self, F, u, bcs=None, J=None,
:param restrict: (optional) If `True`, use restricted function spaces,
that exclude Dirichlet boundary condition nodes, internally for
the test and trial spaces.
:param pre_apply_bcs: (optional) If `False`, the problem is linearised
around the initial guess before imposing the boundary conditions.
"""
V = u.function_space()
self.output_space = V
Expand All @@ -86,6 +88,7 @@ def __init__(self, F, u, bcs=None, J=None,
if isinstance(bc, EquationBC):
restrict = False
self.restrict = restrict
self.pre_apply_bcs = pre_apply_bcs

if restrict and bcs:
V_res = restricted_function_space(V, extract_subdomain_ids(bcs))
Expand Down Expand Up @@ -304,8 +307,9 @@ def solve(self, bounds=None):
problem_dms = [V.dm for V in utils.unique(chain.from_iterable(c.function_space() for c in coefficients)) if V.dm != solution_dm]
problem_dms.append(solution_dm)

for dbc in problem.dirichlet_bcs():
dbc.apply(problem.u_restrict)
if problem.pre_apply_bcs:
for dbc in problem.dirichlet_bcs():
dbc.apply(problem.u_restrict)

if bounds is not None:
lower, upper = bounds
Expand Down Expand Up @@ -340,7 +344,7 @@ class LinearVariationalProblem(NonlinearVariationalProblem):
@PETSc.Log.EventDecorator()
def __init__(self, a, L, u, bcs=None, aP=None,
form_compiler_parameters=None,
constant_jacobian=False, restrict=False):
constant_jacobian=False, restrict=False, pre_apply_bcs=True):
r"""
:param a: the bilinear form
:param L: the linear form
Expand All @@ -358,6 +362,8 @@ def __init__(self, a, L, u, bcs=None, aP=None,
:param restrict: (optional) If `True`, use restricted function spaces,
that exclude Dirichlet boundary condition nodes, internally for
the test and trial spaces.
:param pre_apply_bcs: (optional) If `False`, the problem is linearised
around the initial guess before imposing the boundary conditions.
"""
# In the linear case, the Jacobian is the equation LHS.
J = a
Expand All @@ -373,7 +379,7 @@ def __init__(self, a, L, u, bcs=None, aP=None,

super(LinearVariationalProblem, self).__init__(F, u, bcs, J, aP,
form_compiler_parameters=form_compiler_parameters,
is_linear=True, restrict=restrict)
is_linear=True, restrict=restrict, pre_apply_bcs=pre_apply_bcs)
self._constant_jacobian = constant_jacobian


Expand Down

0 comments on commit ec733f2

Please sign in to comment.