Skip to content

Commit

Permalink
Delay Dirichlet Lifting for LVP, MG, and Fieldsplit
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 29, 2025
1 parent 956bb72 commit 82cd4e1
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 75 deletions.
28 changes: 13 additions & 15 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def extract_form(self, form_type):
# DirichletBC is directly used in assembly.
return self

def _as_nonlinear_variational_problem_arg(self):
def _as_nonlinear_variational_problem_arg(self, is_linear=False):
return self


Expand Down Expand Up @@ -500,16 +500,16 @@ def __init__(self, *args, bcs=None, J=None, Jp=None, V=None, is_linear=False, Jp

# linear
if isinstance(eq.lhs, ufl.Form) and isinstance(eq.rhs, ufl.Form):
J = eq.lhs
J, L = eq.lhs, eq.rhs
Jp = Jp or J
if eq.rhs == 0:
if eq.rhs == 0 or eq.rhs.empty():
F = ufl_expr.action(J, u)
else:
if not isinstance(eq.rhs, (ufl.Form, slate.slate.TensorBase)):
raise TypeError("Provided BC RHS is a '%s', not a Form or Slate Tensor" % type(eq.rhs).__name__)
if len(eq.rhs.arguments()) != 1:
if not isinstance(L, (ufl.BaseForm, slate.slate.TensorBase)):
raise TypeError("Provided BC RHS is a '%s', not a BaseForm or Slate Tensor" % type(L).__name__)
if len(L.arguments()) != 1 and not L.empty():
raise ValueError("Provided BC RHS is not a linear form")
F = ufl_expr.action(J, u) - eq.rhs
F = ufl_expr.action(J, u) - L
self.is_linear = True
# nonlinear
else:
Expand All @@ -531,9 +531,7 @@ def __init__(self, *args, bcs=None, J=None, Jp=None, V=None, is_linear=False, Jp
# reconstruction for splitting `solving_utils.split`
self.Jp_eq_J = Jp_eq_J
self.is_linear = is_linear
self._F = args[0]
self._J = args[1]
self._Jp = args[2]
self._F, self._J, self._Jp = args[:3]
else:
raise TypeError("Wrong EquationBC arguments")

Expand Down Expand Up @@ -562,7 +560,7 @@ def reconstruct(self, V, subu, u, field, is_linear):
if all([_F is not None, _J is not None, _Jp is not None]):
return EquationBC(_F, _J, _Jp, Jp_eq_J=self.Jp_eq_J, is_linear=is_linear)

def _as_nonlinear_variational_problem_arg(self):
def _as_nonlinear_variational_problem_arg(self, is_linear=False):
return self


Expand Down Expand Up @@ -654,19 +652,19 @@ def reconstruct(self, field=None, V=None, subu=None, u=None, row_field=None, col
ebc.add(bc_temp)
return ebc

def _as_nonlinear_variational_problem_arg(self):
def _as_nonlinear_variational_problem_arg(self, is_linear=False):
# NonlinearVariationalProblem expects EquationBC, not EquationBCSplit.
# -- This method is required when NonlinearVariationalProblem is constructed inside PC.
if len(self.f.arguments()) != 2:
raise NotImplementedError(f"Not expecting a form of rank {len(self.f.arguments())} (!= 2)")
J = self.f
Vcol = J.arguments()[-1].function_space()
u = firedrake.Function(Vcol)
F = ufl_expr.action(J, u)
Vrow = self._function_space
sub_domain = self.sub_domain
bcs = tuple(bc._as_nonlinear_variational_problem_arg() for bc in self.bcs)
return EquationBC(F == 0, u, sub_domain, bcs=bcs, J=J, V=Vrow)
bcs = tuple(bc._as_nonlinear_variational_problem_arg(is_linear=is_linear) for bc in self.bcs)
equation = J == ufl.Form([]) if is_linear else ufl_expr.action(J, u) == 0
return EquationBC(equation, u, sub_domain, bcs=bcs, J=J, V=Vrow)


@PETSc.Log.EventDecorator()
Expand Down
8 changes: 5 additions & 3 deletions firedrake/mg/ufl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,8 @@ def inject_on_restrict(fine, restriction, rscale, injection, coarse):
# Apply bcs and also inject them
for bc in chain(*finectx._problem.bcs):
if isinstance(bc, DirichletBC):
bc.apply(finectx._x)
if finectx.pre_apply_bcs:
bc.apply(finectx._x)
g = bc.function_arg
if isinstance(g, firedrake.Function) and hasattr(g, "_child"):
manager.inject(g, g._child)
Expand Down Expand Up @@ -264,7 +265,8 @@ def coarsen_snescontext(context, self, coefficient_mapping=None):
mat_type=context.mat_type,
pmat_type=context.pmat_type,
appctx=new_appctx,
transfer_manager=context.transfer_manager)
transfer_manager=context.transfer_manager,
pre_apply_bcs=context.pre_apply_bcs)
coarse._fine = context
context._coarse = coarse

Expand Down Expand Up @@ -401,7 +403,7 @@ def create_injection(dmc, dmf):

cfn = firedrake.Function(V_c)
ffn = firedrake.Function(V_f)
cbcs = cctx._problem.bcs
cbcs = cctx._problem.bcs if cctx.pre_apply_bcs else None

ctx = Injection(cfn, ffn, manager, cbcs)
mat = PETSc.Mat().create(comm=dmc.comm)
Expand Down
16 changes: 6 additions & 10 deletions firedrake/preconditioners/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,25 +109,21 @@ def get_appctx(pc):
return get_appctx(pc.getDM()).appctx

@staticmethod
def new_snes_ctx(pc, op, bcs, mat_type, fcp=None, options_prefix=None):
def new_snes_ctx(pc, op, bcs, mat_type, fcp=None, options_prefix=None, pre_apply_bcs=True):
""" Create a new SNES contex for nested preconditioning
"""
from firedrake.variational_solver import NonlinearVariationalProblem
from firedrake.variational_solver import LinearVariationalProblem
from firedrake.function import Function
from firedrake.ufl_expr import action
from firedrake.solving_utils import _SNESContext

dm = pc.getDM()
old_appctx = get_appctx(dm).appctx
u = Function(op.arguments()[-1].function_space())
F = action(op, u)
L = 0
if bcs:
bcs = tuple(bc._as_nonlinear_variational_problem_arg() for bc in bcs)
nprob = NonlinearVariationalProblem(F, u,
bcs=bcs,
J=op,
form_compiler_parameters=fcp)
return _SNESContext(nprob, mat_type, mat_type, old_appctx, options_prefix=options_prefix)
bcs = tuple(bc._as_nonlinear_variational_problem_arg(is_linear=True) for bc in bcs)
nprob = LinearVariationalProblem(op, L, u, bcs=bcs, form_compiler_parameters=fcp)
return _SNESContext(nprob, mat_type, mat_type, old_appctx, options_prefix=options_prefix, pre_apply_bcs=pre_apply_bcs)


class PCBase(PCSNESBase):
Expand Down
30 changes: 13 additions & 17 deletions firedrake/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,7 @@ def _solve_varproblem(*args, **kwargs):
problem = vs.LinearVariationalProblem(eq.lhs, eq.rhs, u, bcs, Jp,
form_compiler_parameters=form_compiler_parameters,
restrict=restrict)
# Create solver and call solve
solver = vs.LinearVariationalSolver(problem, solver_parameters=solver_parameters,
nullspace=nullspace,
transpose_nullspace=nullspace_T,
near_nullspace=near_nullspace,
options_prefix=options_prefix,
appctx=appctx)
solver.solve()
create_solver = vs.LinearVariationalSolver

# Solve nonlinear variational problem
else:
Expand All @@ -187,15 +180,18 @@ def _solve_varproblem(*args, **kwargs):
# Create problem
problem = vs.NonlinearVariationalProblem(eq.lhs, u, bcs, J, Jp,
form_compiler_parameters=form_compiler_parameters,
restrict=restrict, pre_apply_bcs=pre_apply_bcs)
# Create solver and call solve
solver = vs.NonlinearVariationalSolver(problem, solver_parameters=solver_parameters,
nullspace=nullspace,
transpose_nullspace=nullspace_T,
near_nullspace=near_nullspace,
options_prefix=options_prefix,
appctx=appctx)
solver.solve()
restrict=restrict)
create_solver = vs.NonlinearVariationalSolver

# Create solver and call solve
solver = create_solver(problem, solver_parameters=solver_parameters,
nullspace=nullspace,
transpose_nullspace=nullspace_T,
near_nullspace=near_nullspace,
options_prefix=options_prefix,
appctx=appctx,
pre_apply_bcs=pre_apply_bcs)
solver.solve()


def _la_solve(A, x, b, **kwargs):
Expand Down
28 changes: 19 additions & 9 deletions firedrake/solving_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from itertools import chain

import numpy
import ufl

from pyop2 import op2
from firedrake import dmhooks
Expand All @@ -9,6 +10,7 @@
from firedrake.exceptions import ConvergenceError
from firedrake.petsc import PETSc, DEFAULT_KSP_PARAMETERS
from firedrake.formmanipulation import ExtractSubBlock
from firedrake.ufl_expr import action
from firedrake.utils import cached_property
from firedrake.logging import warning

Expand Down Expand Up @@ -153,6 +155,8 @@ class _SNESContext(object):
:arg options_prefix: The options prefix of the SNES.
:arg transfer_manager: Object that can transfer functions between
levels, typically a :class:`~.TransferManager`
:arg pre_apply_bcs: If `False`, the problem is linearised
around the initial guess before imposing the boundary conditions.
The idea here is that the SNES holds a shell DM which contains
this object as "user context". When the SNES calls back to the
Expand All @@ -165,14 +169,16 @@ def __init__(self, problem, mat_type, pmat_type, appctx=None,
pre_jacobian_callback=None, pre_function_callback=None,
post_jacobian_callback=None, post_function_callback=None,
options_prefix=None,
transfer_manager=None):
transfer_manager=None,
pre_apply_bcs=True):
from firedrake.assemble import get_assembler

if pmat_type is None:
pmat_type = mat_type
self.mat_type = mat_type
self.pmat_type = pmat_type
self.options_prefix = options_prefix
self.pre_apply_bcs = pre_apply_bcs

matfree = mat_type == 'matfree'
pmatfree = pmat_type == 'matfree'
Expand Down Expand Up @@ -222,14 +228,17 @@ def __init__(self, problem, mat_type, pmat_type, appctx=None,
self.bcs_Jp = tuple(bc.extract_form('Jp') for bc in problem.bcs)

self._bc_residual = None
if not problem.pre_apply_bcs:
if not pre_apply_bcs:
# Delayed lifting of DirichletBCs
self._bc_residual = Function(self._x.function_space())
self.F -= self.J * self._bc_residual
if problem.is_linear:
# Drop the existing BC lifting term in the residual
self.F = ufl.replace(self.F, {self._x: ufl.zero(self._x.ufl_shape)})
self.F -= action(self.J, self._bc_residual)

self._assemble_residual = get_assembler(self.F, bcs=self.bcs_F,
form_compiler_parameters=self.fcp,
zero_bc_nodes=problem.pre_apply_bcs,
zero_bc_nodes=pre_apply_bcs,
).assemble

self._jacobian_assembled = False
Expand Down Expand Up @@ -383,7 +392,8 @@ def split(self, fields):
new_problem._constant_jacobian = problem._constant_jacobian
splits.append(type(self)(new_problem, mat_type=self.mat_type, pmat_type=self.pmat_type,
appctx=self.appctx,
transfer_manager=self.transfer_manager))
transfer_manager=self.transfer_manager,
pre_apply_bcs=self.pre_apply_bcs))
return self._splits.setdefault(tuple(fields), splits)

@staticmethod
Expand All @@ -404,10 +414,11 @@ def form_function(snes, X, F):
if ctx._pre_function_callback is not None:
ctx._pre_function_callback(X)

if ctx._bc_residual is not None:
# Delayed lifting of DirichletBC
if not ctx.pre_apply_bcs:
# Compute DirichletBC residual
for bc in ctx.bcs_F:
bc.apply(ctx._bc_residual, u=ctx._x)

ctx._assemble_residual(tensor=ctx._F, current_state=ctx._x)

if ctx._post_function_callback is not None:
Expand Down Expand Up @@ -483,8 +494,7 @@ def compute_operators(ksp, J, P):
if fine is not None:
manager = dmhooks.get_transfer_manager(fine._x.function_space().dm)
manager.inject(fine._x, ctx._x)

if ctx._problem.pre_apply_bcs:
if ctx.pre_apply_bcs:
for bc in chain(*ctx._problem.bcs):
if isinstance(bc, DirichletBC):
bc.apply(ctx._x)
Expand Down
21 changes: 10 additions & 11 deletions firedrake/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class NonlinearVariationalProblem(NonlinearVariationalProblemMixin):
def __init__(self, F, u, bcs=None, J=None,
Jp=None,
form_compiler_parameters=None,
is_linear=False, restrict=False, pre_apply_bcs=True):
is_linear=False, restrict=False):
r"""
:param F: the nonlinear form
:param u: the :class:`.Function` to solve for
Expand All @@ -68,8 +68,6 @@ def __init__(self, F, u, bcs=None, J=None,
:param restrict: (optional) If `True`, use restricted function spaces,
that exclude Dirichlet boundary condition nodes, internally for
the test and trial spaces.
:param pre_apply_bcs: (optional) If `False`, the problem is linearised
around the initial guess before imposing the boundary conditions.
"""
V = u.function_space()
self.output_space = V
Expand All @@ -88,7 +86,6 @@ def __init__(self, F, u, bcs=None, J=None,
if isinstance(bc, EquationBC):
restrict = False
self.restrict = restrict
self.pre_apply_bcs = pre_apply_bcs

if restrict and bcs:
V_res = restricted_function_space(V, extract_subdomain_ids(bcs))
Expand Down Expand Up @@ -156,7 +153,8 @@ def __init__(self, problem, *, solver_parameters=None,
pre_jacobian_callback=None,
post_jacobian_callback=None,
pre_function_callback=None,
post_function_callback=None):
post_function_callback=None,
pre_apply_bcs=True):
r"""
:arg problem: A :class:`NonlinearVariationalProblem` to solve.
:kwarg nullspace: an optional :class:`.VectorSpaceBasis` (or
Expand Down Expand Up @@ -185,6 +183,8 @@ def __init__(self, problem, *, solver_parameters=None,
before residual assembly.
:kwarg post_function_callback: As above, but called immediately
after residual assembly.
:kwarg pre_apply_bcs: If `False`, the problem is linearised
around the initial guess before imposing the boundary conditions.
Example usage of the ``solver_parameters`` option: to set the
nonlinear solver type to just use a linear solver, use
Expand Down Expand Up @@ -236,7 +236,8 @@ def update_diffusivity(current_solution):
pre_function_callback=pre_function_callback,
post_jacobian_callback=post_jacobian_callback,
post_function_callback=post_function_callback,
options_prefix=self.options_prefix)
options_prefix=self.options_prefix,
pre_apply_bcs=pre_apply_bcs)

self.snes = PETSc.SNES().create(comm=problem.dm.comm)

Expand Down Expand Up @@ -307,7 +308,7 @@ def solve(self, bounds=None):
problem_dms = [V.dm for V in utils.unique(chain.from_iterable(c.function_space() for c in coefficients)) if V.dm != solution_dm]
problem_dms.append(solution_dm)

if problem.pre_apply_bcs:
if self._ctx.pre_apply_bcs:
for dbc in problem.dirichlet_bcs():
dbc.apply(problem.u_restrict)

Expand Down Expand Up @@ -344,7 +345,7 @@ class LinearVariationalProblem(NonlinearVariationalProblem):
@PETSc.Log.EventDecorator()
def __init__(self, a, L, u, bcs=None, aP=None,
form_compiler_parameters=None,
constant_jacobian=False, restrict=False, pre_apply_bcs=True):
constant_jacobian=False, restrict=False):
r"""
:param a: the bilinear form
:param L: the linear form
Expand All @@ -362,8 +363,6 @@ def __init__(self, a, L, u, bcs=None, aP=None,
:param restrict: (optional) If `True`, use restricted function spaces,
that exclude Dirichlet boundary condition nodes, internally for
the test and trial spaces.
:param pre_apply_bcs: (optional) If `False`, the problem is linearised
around the initial guess before imposing the boundary conditions.
"""
# In the linear case, the Jacobian is the equation LHS.
J = a
Expand All @@ -379,7 +378,7 @@ def __init__(self, a, L, u, bcs=None, aP=None,

super(LinearVariationalProblem, self).__init__(F, u, bcs, J, aP,
form_compiler_parameters=form_compiler_parameters,
is_linear=True, restrict=restrict, pre_apply_bcs=pre_apply_bcs)
is_linear=True, restrict=restrict)
self._constant_jacobian = constant_jacobian


Expand Down
Loading

0 comments on commit 82cd4e1

Please sign in to comment.