Skip to content

Commit

Permalink
Delay Dirichlet Lifting for LVP, MG, and Fieldsplit
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 29, 2025
1 parent 956bb72 commit 20985af
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 92 deletions.
28 changes: 13 additions & 15 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def extract_form(self, form_type):
# DirichletBC is directly used in assembly.
return self

def _as_nonlinear_variational_problem_arg(self):
def _as_nonlinear_variational_problem_arg(self, is_linear=False):
return self


Expand Down Expand Up @@ -500,16 +500,16 @@ def __init__(self, *args, bcs=None, J=None, Jp=None, V=None, is_linear=False, Jp

# linear
if isinstance(eq.lhs, ufl.Form) and isinstance(eq.rhs, ufl.Form):
J = eq.lhs
J, L = eq.lhs, eq.rhs
Jp = Jp or J
if eq.rhs == 0:
if L == 0 or L.empty():
F = ufl_expr.action(J, u)
else:
if not isinstance(eq.rhs, (ufl.Form, slate.slate.TensorBase)):
raise TypeError("Provided BC RHS is a '%s', not a Form or Slate Tensor" % type(eq.rhs).__name__)
if len(eq.rhs.arguments()) != 1:
if not isinstance(L, (ufl.BaseForm, slate.slate.TensorBase)):
raise TypeError("Provided BC RHS is a '%s', not a BaseForm or Slate Tensor" % type(L).__name__)
if len(L.arguments()) != 1:
raise ValueError("Provided BC RHS is not a linear form")
F = ufl_expr.action(J, u) - eq.rhs
F = ufl_expr.action(J, u) - L
self.is_linear = True
# nonlinear
else:
Expand All @@ -531,9 +531,7 @@ def __init__(self, *args, bcs=None, J=None, Jp=None, V=None, is_linear=False, Jp
# reconstruction for splitting `solving_utils.split`
self.Jp_eq_J = Jp_eq_J
self.is_linear = is_linear
self._F = args[0]
self._J = args[1]
self._Jp = args[2]
self._F, self._J, self._Jp = args[:3]
else:
raise TypeError("Wrong EquationBC arguments")

Expand Down Expand Up @@ -562,7 +560,7 @@ def reconstruct(self, V, subu, u, field, is_linear):
if all([_F is not None, _J is not None, _Jp is not None]):
return EquationBC(_F, _J, _Jp, Jp_eq_J=self.Jp_eq_J, is_linear=is_linear)

def _as_nonlinear_variational_problem_arg(self):
def _as_nonlinear_variational_problem_arg(self, is_linear=False):
return self


Expand Down Expand Up @@ -654,19 +652,19 @@ def reconstruct(self, field=None, V=None, subu=None, u=None, row_field=None, col
ebc.add(bc_temp)
return ebc

def _as_nonlinear_variational_problem_arg(self):
def _as_nonlinear_variational_problem_arg(self, is_linear=False):
# NonlinearVariationalProblem expects EquationBC, not EquationBCSplit.
# -- This method is required when NonlinearVariationalProblem is constructed inside PC.
if len(self.f.arguments()) != 2:
raise NotImplementedError(f"Not expecting a form of rank {len(self.f.arguments())} (!= 2)")
J = self.f
Vcol = J.arguments()[-1].function_space()
u = firedrake.Function(Vcol)
F = ufl_expr.action(J, u)
Vrow = self._function_space
sub_domain = self.sub_domain
bcs = tuple(bc._as_nonlinear_variational_problem_arg() for bc in self.bcs)
return EquationBC(F == 0, u, sub_domain, bcs=bcs, J=J, V=Vrow)
bcs = tuple(bc._as_nonlinear_variational_problem_arg(is_linear=is_linear) for bc in self.bcs)
equation = J == ufl.Form([]) if is_linear else ufl_expr.action(J, u) == 0
return EquationBC(equation, u, sub_domain, bcs=bcs, J=J, V=Vrow)


@PETSc.Log.EventDecorator()
Expand Down
11 changes: 6 additions & 5 deletions firedrake/mg/ufl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,8 @@ def inject_on_restrict(fine, restriction, rscale, injection, coarse):
# Apply bcs and also inject them
for bc in chain(*finectx._problem.bcs):
if isinstance(bc, DirichletBC):
bc.apply(finectx._x)
if finectx.pre_apply_bcs:
bc.apply(finectx._x)
g = bc.function_arg
if isinstance(g, firedrake.Function) and hasattr(g, "_child"):
manager.inject(g, g._child)
Expand Down Expand Up @@ -208,7 +209,7 @@ def inject_on_restrict(fine, restriction, rscale, injection, coarse):
F = self(problem.F, self, coefficient_mapping=coefficient_mapping)

fine = problem
problem = firedrake.NonlinearVariationalProblem(F, u, bcs=bcs, J=J, Jp=Jp,
problem = firedrake.NonlinearVariationalProblem(F, u, bcs=bcs, J=J, Jp=Jp, is_linear=problem.is_linear,
form_compiler_parameters=problem.form_compiler_parameters)
fine._coarse = problem
return problem
Expand Down Expand Up @@ -264,7 +265,8 @@ def coarsen_snescontext(context, self, coefficient_mapping=None):
mat_type=context.mat_type,
pmat_type=context.pmat_type,
appctx=new_appctx,
transfer_manager=context.transfer_manager)
transfer_manager=context.transfer_manager,
pre_apply_bcs=context.pre_apply_bcs)
coarse._fine = context
context._coarse = coarse

Expand Down Expand Up @@ -401,9 +403,8 @@ def create_injection(dmc, dmf):

cfn = firedrake.Function(V_c)
ffn = firedrake.Function(V_f)
cbcs = cctx._problem.bcs

ctx = Injection(cfn, ffn, manager, cbcs)
ctx = Injection(cfn, ffn, manager)
mat = PETSc.Mat().create(comm=dmc.comm)
mat.setSizes((row_size, col_size))
mat.setType(mat.Type.PYTHON)
Expand Down
16 changes: 6 additions & 10 deletions firedrake/preconditioners/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,25 +109,21 @@ def get_appctx(pc):
return get_appctx(pc.getDM()).appctx

@staticmethod
def new_snes_ctx(pc, op, bcs, mat_type, fcp=None, options_prefix=None):
def new_snes_ctx(pc, op, bcs, mat_type, fcp=None, options_prefix=None, pre_apply_bcs=True):
""" Create a new SNES contex for nested preconditioning
"""
from firedrake.variational_solver import NonlinearVariationalProblem
from firedrake.variational_solver import LinearVariationalProblem
from firedrake.function import Function
from firedrake.ufl_expr import action
from firedrake.solving_utils import _SNESContext

dm = pc.getDM()
old_appctx = get_appctx(dm).appctx
u = Function(op.arguments()[-1].function_space())
F = action(op, u)
L = 0
if bcs:
bcs = tuple(bc._as_nonlinear_variational_problem_arg() for bc in bcs)
nprob = NonlinearVariationalProblem(F, u,
bcs=bcs,
J=op,
form_compiler_parameters=fcp)
return _SNESContext(nprob, mat_type, mat_type, old_appctx, options_prefix=options_prefix)
bcs = tuple(bc._as_nonlinear_variational_problem_arg(is_linear=True) for bc in bcs)
nprob = LinearVariationalProblem(op, L, u, bcs=bcs, form_compiler_parameters=fcp)
return _SNESContext(nprob, mat_type, mat_type, old_appctx, options_prefix=options_prefix, pre_apply_bcs=pre_apply_bcs)


class PCBase(PCSNESBase):
Expand Down
30 changes: 13 additions & 17 deletions firedrake/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,7 @@ def _solve_varproblem(*args, **kwargs):
problem = vs.LinearVariationalProblem(eq.lhs, eq.rhs, u, bcs, Jp,
form_compiler_parameters=form_compiler_parameters,
restrict=restrict)
# Create solver and call solve
solver = vs.LinearVariationalSolver(problem, solver_parameters=solver_parameters,
nullspace=nullspace,
transpose_nullspace=nullspace_T,
near_nullspace=near_nullspace,
options_prefix=options_prefix,
appctx=appctx)
solver.solve()
create_solver = vs.LinearVariationalSolver

# Solve nonlinear variational problem
else:
Expand All @@ -187,15 +180,18 @@ def _solve_varproblem(*args, **kwargs):
# Create problem
problem = vs.NonlinearVariationalProblem(eq.lhs, u, bcs, J, Jp,
form_compiler_parameters=form_compiler_parameters,
restrict=restrict, pre_apply_bcs=pre_apply_bcs)
# Create solver and call solve
solver = vs.NonlinearVariationalSolver(problem, solver_parameters=solver_parameters,
nullspace=nullspace,
transpose_nullspace=nullspace_T,
near_nullspace=near_nullspace,
options_prefix=options_prefix,
appctx=appctx)
solver.solve()
restrict=restrict)
create_solver = vs.NonlinearVariationalSolver

# Create solver and call solve
solver = create_solver(problem, solver_parameters=solver_parameters,
nullspace=nullspace,
transpose_nullspace=nullspace_T,
near_nullspace=near_nullspace,
options_prefix=options_prefix,
appctx=appctx,
pre_apply_bcs=pre_apply_bcs)
solver.solve()


def _la_solve(A, x, b, **kwargs):
Expand Down
38 changes: 23 additions & 15 deletions firedrake/solving_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from itertools import chain

import numpy
import ufl

from pyop2 import op2
from firedrake import dmhooks
Expand All @@ -9,6 +10,7 @@
from firedrake.exceptions import ConvergenceError
from firedrake.petsc import PETSc, DEFAULT_KSP_PARAMETERS
from firedrake.formmanipulation import ExtractSubBlock
from firedrake.ufl_expr import action
from firedrake.utils import cached_property
from firedrake.logging import warning

Expand Down Expand Up @@ -153,6 +155,8 @@ class _SNESContext(object):
:arg options_prefix: The options prefix of the SNES.
:arg transfer_manager: Object that can transfer functions between
levels, typically a :class:`~.TransferManager`
:arg pre_apply_bcs: If `False`, the problem is linearised
around the initial guess before imposing the boundary conditions.
The idea here is that the SNES holds a shell DM which contains
this object as "user context". When the SNES calls back to the
Expand All @@ -165,14 +169,16 @@ def __init__(self, problem, mat_type, pmat_type, appctx=None,
pre_jacobian_callback=None, pre_function_callback=None,
post_jacobian_callback=None, post_function_callback=None,
options_prefix=None,
transfer_manager=None):
transfer_manager=None,
pre_apply_bcs=True):
from firedrake.assemble import get_assembler

if pmat_type is None:
pmat_type = mat_type
self.mat_type = mat_type
self.pmat_type = pmat_type
self.options_prefix = options_prefix
self.pre_apply_bcs = pre_apply_bcs

matfree = mat_type == 'matfree'
pmatfree = pmat_type == 'matfree'
Expand Down Expand Up @@ -222,14 +228,17 @@ def __init__(self, problem, mat_type, pmat_type, appctx=None,
self.bcs_Jp = tuple(bc.extract_form('Jp') for bc in problem.bcs)

self._bc_residual = None
if not problem.pre_apply_bcs:
if not pre_apply_bcs:
# Delayed lifting of DirichletBCs
self._bc_residual = Function(self._x.function_space())
self.F -= self.J * self._bc_residual
if problem.is_linear:
# Drop the existing BC lifting term in the residual
self.F = ufl.replace(self.F, {self._x: ufl.zero(self._x.ufl_shape)})
self.F -= action(self.J, self._bc_residual)

self._assemble_residual = get_assembler(self.F, bcs=self.bcs_F,
form_compiler_parameters=self.fcp,
zero_bc_nodes=problem.pre_apply_bcs,
zero_bc_nodes=pre_apply_bcs,
).assemble

self._jacobian_assembled = False
Expand Down Expand Up @@ -375,15 +384,16 @@ def split(self, fields):
if isinstance(bc, DirichletBC):
bc_temp = bc.reconstruct(field=field, V=V, g=bc.function_arg, sub_domain=bc.sub_domain)
elif isinstance(bc, EquationBC):
bc_temp = bc.reconstruct(V, subu, u, field, False)
bc_temp = bc.reconstruct(V, subu, u, field, problem.is_linear)
if bc_temp is not None:
bcs.append(bc_temp)
new_problem = NLVP(F, subu, bcs=bcs, J=J, Jp=Jp,
new_problem = NLVP(F, subu, bcs=bcs, J=J, Jp=Jp, is_linear=problem.is_linear,
form_compiler_parameters=problem.form_compiler_parameters)
new_problem._constant_jacobian = problem._constant_jacobian
splits.append(type(self)(new_problem, mat_type=self.mat_type, pmat_type=self.pmat_type,
appctx=self.appctx,
transfer_manager=self.transfer_manager))
transfer_manager=self.transfer_manager,
pre_apply_bcs=self.pre_apply_bcs))
return self._splits.setdefault(tuple(fields), splits)

@staticmethod
Expand All @@ -404,10 +414,11 @@ def form_function(snes, X, F):
if ctx._pre_function_callback is not None:
ctx._pre_function_callback(X)

if ctx._bc_residual is not None:
# Delayed lifting of DirichletBC
if not ctx.pre_apply_bcs:
# Compute DirichletBC residual
for bc in ctx.bcs_F:
bc.apply(ctx._bc_residual, u=ctx._x)

ctx._assemble_residual(tensor=ctx._F, current_state=ctx._x)

if ctx._post_function_callback is not None:
Expand Down Expand Up @@ -468,7 +479,6 @@ def compute_operators(ksp, J, P):
:arg J: the Jacobian (a Mat)
:arg P: the preconditioner matrix (a Mat)
"""
from firedrake.bcs import DirichletBC
dm = ksp.getDM()
ctx = dmhooks.get_appctx(dm)
problem = ctx._problem
Expand All @@ -483,11 +493,9 @@ def compute_operators(ksp, J, P):
if fine is not None:
manager = dmhooks.get_transfer_manager(fine._x.function_space().dm)
manager.inject(fine._x, ctx._x)

if ctx._problem.pre_apply_bcs:
for bc in chain(*ctx._problem.bcs):
if isinstance(bc, DirichletBC):
bc.apply(ctx._x)
if ctx.pre_apply_bcs:
for bc in problem.dirichlet_bcs():
bc.apply(ctx._x)

ctx._assemble_jac(ctx._jac)
if ctx.Jp is not None:
Expand Down
Loading

0 comments on commit 20985af

Please sign in to comment.