diff --git a/firedrake/bcs.py b/firedrake/bcs.py index 5884907feb..564dbf0a77 100644 --- a/firedrake/bcs.py +++ b/firedrake/bcs.py @@ -462,7 +462,7 @@ def extract_form(self, form_type): # DirichletBC is directly used in assembly. return self - def _as_nonlinear_variational_problem_arg(self): + def _as_nonlinear_variational_problem_arg(self, is_linear=False): return self @@ -500,16 +500,16 @@ def __init__(self, *args, bcs=None, J=None, Jp=None, V=None, is_linear=False, Jp # linear if isinstance(eq.lhs, ufl.Form) and isinstance(eq.rhs, ufl.Form): - J = eq.lhs + J, L = eq.lhs, eq.rhs Jp = Jp or J - if eq.rhs == 0: + if eq.rhs == 0 or eq.rhs.empty(): F = ufl_expr.action(J, u) else: - if not isinstance(eq.rhs, (ufl.Form, slate.slate.TensorBase)): - raise TypeError("Provided BC RHS is a '%s', not a Form or Slate Tensor" % type(eq.rhs).__name__) - if len(eq.rhs.arguments()) != 1: + if not isinstance(L, (ufl.BaseForm, slate.slate.TensorBase)): + raise TypeError("Provided BC RHS is a '%s', not a BaseForm or Slate Tensor" % type(L).__name__) + if len(L.arguments()) != 1 and not L.empty(): raise ValueError("Provided BC RHS is not a linear form") - F = ufl_expr.action(J, u) - eq.rhs + F = ufl_expr.action(J, u) - L self.is_linear = True # nonlinear else: @@ -531,9 +531,7 @@ def __init__(self, *args, bcs=None, J=None, Jp=None, V=None, is_linear=False, Jp # reconstruction for splitting `solving_utils.split` self.Jp_eq_J = Jp_eq_J self.is_linear = is_linear - self._F = args[0] - self._J = args[1] - self._Jp = args[2] + self._F, self._J, self._Jp = args[:3] else: raise TypeError("Wrong EquationBC arguments") @@ -562,7 +560,7 @@ def reconstruct(self, V, subu, u, field, is_linear): if all([_F is not None, _J is not None, _Jp is not None]): return EquationBC(_F, _J, _Jp, Jp_eq_J=self.Jp_eq_J, is_linear=is_linear) - def _as_nonlinear_variational_problem_arg(self): + def _as_nonlinear_variational_problem_arg(self, is_linear=False): return self @@ -654,7 +652,7 @@ def reconstruct(self, field=None, V=None, subu=None, u=None, row_field=None, col ebc.add(bc_temp) return ebc - def _as_nonlinear_variational_problem_arg(self): + def _as_nonlinear_variational_problem_arg(self, is_linear=False): # NonlinearVariationalProblem expects EquationBC, not EquationBCSplit. # -- This method is required when NonlinearVariationalProblem is constructed inside PC. if len(self.f.arguments()) != 2: @@ -662,11 +660,11 @@ def _as_nonlinear_variational_problem_arg(self): J = self.f Vcol = J.arguments()[-1].function_space() u = firedrake.Function(Vcol) - F = ufl_expr.action(J, u) Vrow = self._function_space sub_domain = self.sub_domain - bcs = tuple(bc._as_nonlinear_variational_problem_arg() for bc in self.bcs) - return EquationBC(F == 0, u, sub_domain, bcs=bcs, J=J, V=Vrow) + bcs = tuple(bc._as_nonlinear_variational_problem_arg(is_linear=is_linear) for bc in self.bcs) + equation = J == ufl.Form([]) if is_linear else ufl_expr.action(J, u) == 0 + return EquationBC(equation, u, sub_domain, bcs=bcs, J=J, V=Vrow) @PETSc.Log.EventDecorator() diff --git a/firedrake/mg/ufl_utils.py b/firedrake/mg/ufl_utils.py index 0408a6d2a6..43763d1bde 100644 --- a/firedrake/mg/ufl_utils.py +++ b/firedrake/mg/ufl_utils.py @@ -180,7 +180,8 @@ def inject_on_restrict(fine, restriction, rscale, injection, coarse): # Apply bcs and also inject them for bc in chain(*finectx._problem.bcs): if isinstance(bc, DirichletBC): - bc.apply(finectx._x) + if finectx.pre_apply_bcs: + bc.apply(finectx._x) g = bc.function_arg if isinstance(g, firedrake.Function) and hasattr(g, "_child"): manager.inject(g, g._child) @@ -264,7 +265,8 @@ def coarsen_snescontext(context, self, coefficient_mapping=None): mat_type=context.mat_type, pmat_type=context.pmat_type, appctx=new_appctx, - transfer_manager=context.transfer_manager) + transfer_manager=context.transfer_manager, + pre_apply_bcs=context.pre_apply_bcs) coarse._fine = context context._coarse = coarse @@ -401,7 +403,7 @@ def create_injection(dmc, dmf): cfn = firedrake.Function(V_c) ffn = firedrake.Function(V_f) - cbcs = cctx._problem.bcs + cbcs = cctx._problem.bcs if cctx.pre_apply_bcs else None ctx = Injection(cfn, ffn, manager, cbcs) mat = PETSc.Mat().create(comm=dmc.comm) diff --git a/firedrake/preconditioners/base.py b/firedrake/preconditioners/base.py index 0bdfc97a37..ba72335a97 100644 --- a/firedrake/preconditioners/base.py +++ b/firedrake/preconditioners/base.py @@ -109,25 +109,21 @@ def get_appctx(pc): return get_appctx(pc.getDM()).appctx @staticmethod - def new_snes_ctx(pc, op, bcs, mat_type, fcp=None, options_prefix=None): + def new_snes_ctx(pc, op, bcs, mat_type, fcp=None, options_prefix=None, pre_apply_bcs=True): """ Create a new SNES contex for nested preconditioning """ - from firedrake.variational_solver import NonlinearVariationalProblem + from firedrake.variational_solver import LinearVariationalProblem from firedrake.function import Function - from firedrake.ufl_expr import action from firedrake.solving_utils import _SNESContext dm = pc.getDM() old_appctx = get_appctx(dm).appctx u = Function(op.arguments()[-1].function_space()) - F = action(op, u) + L = 0 if bcs: - bcs = tuple(bc._as_nonlinear_variational_problem_arg() for bc in bcs) - nprob = NonlinearVariationalProblem(F, u, - bcs=bcs, - J=op, - form_compiler_parameters=fcp) - return _SNESContext(nprob, mat_type, mat_type, old_appctx, options_prefix=options_prefix) + bcs = tuple(bc._as_nonlinear_variational_problem_arg(is_linear=True) for bc in bcs) + nprob = LinearVariationalProblem(op, L, u, bcs=bcs, form_compiler_parameters=fcp) + return _SNESContext(nprob, mat_type, mat_type, old_appctx, options_prefix=options_prefix, pre_apply_bcs=pre_apply_bcs) class PCBase(PCSNESBase): diff --git a/firedrake/solving.py b/firedrake/solving.py index ef80d2f4f1..6a2252c40e 100644 --- a/firedrake/solving.py +++ b/firedrake/solving.py @@ -171,14 +171,7 @@ def _solve_varproblem(*args, **kwargs): problem = vs.LinearVariationalProblem(eq.lhs, eq.rhs, u, bcs, Jp, form_compiler_parameters=form_compiler_parameters, restrict=restrict) - # Create solver and call solve - solver = vs.LinearVariationalSolver(problem, solver_parameters=solver_parameters, - nullspace=nullspace, - transpose_nullspace=nullspace_T, - near_nullspace=near_nullspace, - options_prefix=options_prefix, - appctx=appctx) - solver.solve() + create_solver = vs.LinearVariationalSolver # Solve nonlinear variational problem else: @@ -187,15 +180,18 @@ def _solve_varproblem(*args, **kwargs): # Create problem problem = vs.NonlinearVariationalProblem(eq.lhs, u, bcs, J, Jp, form_compiler_parameters=form_compiler_parameters, - restrict=restrict, pre_apply_bcs=pre_apply_bcs) - # Create solver and call solve - solver = vs.NonlinearVariationalSolver(problem, solver_parameters=solver_parameters, - nullspace=nullspace, - transpose_nullspace=nullspace_T, - near_nullspace=near_nullspace, - options_prefix=options_prefix, - appctx=appctx) - solver.solve() + restrict=restrict) + create_solver = vs.NonlinearVariationalSolver + + # Create solver and call solve + solver = create_solver(problem, solver_parameters=solver_parameters, + nullspace=nullspace, + transpose_nullspace=nullspace_T, + near_nullspace=near_nullspace, + options_prefix=options_prefix, + appctx=appctx, + pre_apply_bcs=pre_apply_bcs) + solver.solve() def _la_solve(A, x, b, **kwargs): diff --git a/firedrake/solving_utils.py b/firedrake/solving_utils.py index b9e44fc69d..4dd408acf2 100644 --- a/firedrake/solving_utils.py +++ b/firedrake/solving_utils.py @@ -1,6 +1,7 @@ from itertools import chain import numpy +import ufl from pyop2 import op2 from firedrake import dmhooks @@ -9,6 +10,7 @@ from firedrake.exceptions import ConvergenceError from firedrake.petsc import PETSc, DEFAULT_KSP_PARAMETERS from firedrake.formmanipulation import ExtractSubBlock +from firedrake.ufl_expr import action from firedrake.utils import cached_property from firedrake.logging import warning @@ -153,6 +155,8 @@ class _SNESContext(object): :arg options_prefix: The options prefix of the SNES. :arg transfer_manager: Object that can transfer functions between levels, typically a :class:`~.TransferManager` + :arg pre_apply_bcs: If `False`, the problem is linearised + around the initial guess before imposing the boundary conditions. The idea here is that the SNES holds a shell DM which contains this object as "user context". When the SNES calls back to the @@ -165,7 +169,8 @@ def __init__(self, problem, mat_type, pmat_type, appctx=None, pre_jacobian_callback=None, pre_function_callback=None, post_jacobian_callback=None, post_function_callback=None, options_prefix=None, - transfer_manager=None): + transfer_manager=None, + pre_apply_bcs=True): from firedrake.assemble import get_assembler if pmat_type is None: @@ -173,6 +178,7 @@ def __init__(self, problem, mat_type, pmat_type, appctx=None, self.mat_type = mat_type self.pmat_type = pmat_type self.options_prefix = options_prefix + self.pre_apply_bcs = pre_apply_bcs matfree = mat_type == 'matfree' pmatfree = pmat_type == 'matfree' @@ -222,14 +228,17 @@ def __init__(self, problem, mat_type, pmat_type, appctx=None, self.bcs_Jp = tuple(bc.extract_form('Jp') for bc in problem.bcs) self._bc_residual = None - if not problem.pre_apply_bcs: + if not pre_apply_bcs: # Delayed lifting of DirichletBCs self._bc_residual = Function(self._x.function_space()) - self.F -= self.J * self._bc_residual + if problem.is_linear: + # Drop the existing BC lifting term in the residual + self.F = ufl.replace(self.F, {self._x: ufl.zero(self._x.ufl_shape)}) + self.F -= action(self.J, self._bc_residual) self._assemble_residual = get_assembler(self.F, bcs=self.bcs_F, form_compiler_parameters=self.fcp, - zero_bc_nodes=problem.pre_apply_bcs, + zero_bc_nodes=pre_apply_bcs, ).assemble self._jacobian_assembled = False @@ -383,7 +392,8 @@ def split(self, fields): new_problem._constant_jacobian = problem._constant_jacobian splits.append(type(self)(new_problem, mat_type=self.mat_type, pmat_type=self.pmat_type, appctx=self.appctx, - transfer_manager=self.transfer_manager)) + transfer_manager=self.transfer_manager, + pre_apply_bcs=self.pre_apply_bcs)) return self._splits.setdefault(tuple(fields), splits) @staticmethod @@ -404,10 +414,11 @@ def form_function(snes, X, F): if ctx._pre_function_callback is not None: ctx._pre_function_callback(X) - if ctx._bc_residual is not None: - # Delayed lifting of DirichletBC + if not ctx.pre_apply_bcs: + # Compute DirichletBC residual for bc in ctx.bcs_F: bc.apply(ctx._bc_residual, u=ctx._x) + ctx._assemble_residual(tensor=ctx._F, current_state=ctx._x) if ctx._post_function_callback is not None: @@ -483,8 +494,7 @@ def compute_operators(ksp, J, P): if fine is not None: manager = dmhooks.get_transfer_manager(fine._x.function_space().dm) manager.inject(fine._x, ctx._x) - - if ctx._problem.pre_apply_bcs: + if ctx.pre_apply_bcs: for bc in chain(*ctx._problem.bcs): if isinstance(bc, DirichletBC): bc.apply(ctx._x) diff --git a/firedrake/variational_solver.py b/firedrake/variational_solver.py index fcd3c78bfe..ad1b0fdf62 100644 --- a/firedrake/variational_solver.py +++ b/firedrake/variational_solver.py @@ -52,7 +52,7 @@ class NonlinearVariationalProblem(NonlinearVariationalProblemMixin): def __init__(self, F, u, bcs=None, J=None, Jp=None, form_compiler_parameters=None, - is_linear=False, restrict=False, pre_apply_bcs=True): + is_linear=False, restrict=False): r""" :param F: the nonlinear form :param u: the :class:`.Function` to solve for @@ -68,8 +68,6 @@ def __init__(self, F, u, bcs=None, J=None, :param restrict: (optional) If `True`, use restricted function spaces, that exclude Dirichlet boundary condition nodes, internally for the test and trial spaces. - :param pre_apply_bcs: (optional) If `False`, the problem is linearised - around the initial guess before imposing the boundary conditions. """ V = u.function_space() self.output_space = V @@ -88,7 +86,6 @@ def __init__(self, F, u, bcs=None, J=None, if isinstance(bc, EquationBC): restrict = False self.restrict = restrict - self.pre_apply_bcs = pre_apply_bcs if restrict and bcs: V_res = restricted_function_space(V, extract_subdomain_ids(bcs)) @@ -156,7 +153,8 @@ def __init__(self, problem, *, solver_parameters=None, pre_jacobian_callback=None, post_jacobian_callback=None, pre_function_callback=None, - post_function_callback=None): + post_function_callback=None, + pre_apply_bcs=True): r""" :arg problem: A :class:`NonlinearVariationalProblem` to solve. :kwarg nullspace: an optional :class:`.VectorSpaceBasis` (or @@ -185,6 +183,8 @@ def __init__(self, problem, *, solver_parameters=None, before residual assembly. :kwarg post_function_callback: As above, but called immediately after residual assembly. + :kwarg pre_apply_bcs: If `False`, the problem is linearised + around the initial guess before imposing the boundary conditions. Example usage of the ``solver_parameters`` option: to set the nonlinear solver type to just use a linear solver, use @@ -236,7 +236,8 @@ def update_diffusivity(current_solution): pre_function_callback=pre_function_callback, post_jacobian_callback=post_jacobian_callback, post_function_callback=post_function_callback, - options_prefix=self.options_prefix) + options_prefix=self.options_prefix, + pre_apply_bcs=pre_apply_bcs) self.snes = PETSc.SNES().create(comm=problem.dm.comm) @@ -307,7 +308,7 @@ def solve(self, bounds=None): problem_dms = [V.dm for V in utils.unique(chain.from_iterable(c.function_space() for c in coefficients)) if V.dm != solution_dm] problem_dms.append(solution_dm) - if problem.pre_apply_bcs: + if self._ctx.pre_apply_bcs: for dbc in problem.dirichlet_bcs(): dbc.apply(problem.u_restrict) @@ -344,7 +345,7 @@ class LinearVariationalProblem(NonlinearVariationalProblem): @PETSc.Log.EventDecorator() def __init__(self, a, L, u, bcs=None, aP=None, form_compiler_parameters=None, - constant_jacobian=False, restrict=False, pre_apply_bcs=True): + constant_jacobian=False, restrict=False): r""" :param a: the bilinear form :param L: the linear form @@ -362,8 +363,6 @@ def __init__(self, a, L, u, bcs=None, aP=None, :param restrict: (optional) If `True`, use restricted function spaces, that exclude Dirichlet boundary condition nodes, internally for the test and trial spaces. - :param pre_apply_bcs: (optional) If `False`, the problem is linearised - around the initial guess before imposing the boundary conditions. """ # In the linear case, the Jacobian is the equation LHS. J = a @@ -379,7 +378,7 @@ def __init__(self, a, L, u, bcs=None, aP=None, super(LinearVariationalProblem, self).__init__(F, u, bcs, J, aP, form_compiler_parameters=form_compiler_parameters, - is_linear=True, restrict=restrict, pre_apply_bcs=pre_apply_bcs) + is_linear=True, restrict=restrict) self._constant_jacobian = constant_jacobian diff --git a/tests/firedrake/regression/test_solving_interface.py b/tests/firedrake/regression/test_solving_interface.py index 45344eb16e..6dcd4668ea 100644 --- a/tests/firedrake/regression/test_solving_interface.py +++ b/tests/firedrake/regression/test_solving_interface.py @@ -262,26 +262,40 @@ def test_solve_empty_form_rhs(mesh): assert errornorm(x, w) < 1E-10 +@pytest.mark.skipif(utils.complex_mode, reason="Differentiation of energy not defined in Complex.") def test_solve_pre_apply_bcs(mesh): + """Solve a 1D hyperelasticity problem with linear exact solution. + The default DirichletBC treatment would raise NaNs if the problem is + linearised around an initial guess that satisfies the bcs and is zero on the interior. + Here we test that we can linearise around an initial point that + does not satisfy the DirichletBCs by passing pre_apply_bcs=True.""" + V = VectorFunctionSpace(mesh, "CG", 1) - x = SpatialCoordinate(mesh) + u = Function(V) + # Boundary conditions eps = Constant(0.1) - g = -eps*x - - bc = DirichletBC(V, g, "on_boundary") - - u = Function(V) + x = SpatialCoordinate(mesh) + g = -eps * x + bcs = [DirichletBC(V, g, "on_boundary")] + # Hyperelastic energy functional lam = Constant(1E0) dim = mesh.geometric_dimension() F = grad(u) + Identity(dim) J = det(F) - logJ = 0.5*ln(J*conj(J)) - + logJ = 0.5*ln(J**2) W = (1/2)*(inner(F, F) - dim - 2*logJ + lam*logJ**2) * dx + # Raises NaNs if pre_apply_bcs=True F = derivative(W, u) - # Raises nans if pre_apply_bcs=True - solve(F == 0, u, bc, pre_apply_bcs=False) + u.zero() + solve(F == 0, u, bcs, pre_apply_bcs=False) + assert errornorm(g, u) < 1E-10 + + # Test that pre_apply_bcs=False works on a linear solver + a = derivative(F, u) + L = Form([]) + u.zero() + solve(a == L, u, bcs, pre_apply_bcs=False) assert errornorm(g, u) < 1E-10