From ec733f25853dc075250f0937078df82faf99f971 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Mon, 27 Jan 2025 16:54:43 +0000 Subject: [PATCH] FEniCS-style bcs --- firedrake/assemble.py | 48 ++++++++++++++++++--------------- firedrake/solving.py | 12 ++++++--- firedrake/solving_utils.py | 3 ++- firedrake/variational_solver.py | 16 +++++++---- 4 files changed, 48 insertions(+), 31 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index b093baf945..06f67bab3c 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -358,7 +358,7 @@ def allocation_integral_types(self): else: return self._allocation_integral_types - def assemble(self, tensor=None): + def assemble(self, tensor=None, current_state=None): """Assemble the form. Parameters @@ -389,7 +389,7 @@ def visitor(e, *operands): rank = len(self._form.arguments()) if rank == 1 and not isinstance(result, ufl.ZeroBaseForm): for bc in self._bcs: - bc.zero(result) + OneFormAssembler._apply_bc(self, result, bc, u=current_state) if tensor: BaseFormAssembler.update_tensor(result, tensor) @@ -968,7 +968,7 @@ def __init__(self, form, bcs=None, form_compiler_parameters=None, needs_zeroing= super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters) self._needs_zeroing = needs_zeroing - def assemble(self, tensor=None): + def assemble(self, tensor=None, current_state=None): """Assemble the form. Parameters @@ -998,12 +998,12 @@ def assemble(self, tensor=None): self.execute_parloops(tensor) for bc in self._bcs: - self._apply_bc(tensor, bc) + self._apply_bc(tensor, bc, u=current_state) return self.result(tensor) @abc.abstractmethod - def _apply_bc(self, tensor, bc): + def _apply_bc(self, tensor, bc, u=None): """Apply boundary condition.""" @abc.abstractmethod @@ -1138,7 +1138,7 @@ def allocate(self): comm=self._form.ufl_domains()[0]._comm ) - def _apply_bc(self, tensor, bc): + def _apply_bc(self, tensor, bc, u=None): pass def _check_tensor(self, tensor): @@ -1199,26 +1199,31 @@ def allocate(self): else: raise RuntimeError(f"Not expected: found rank = {rank} and diagonal = {self._diagonal}") - def _apply_bc(self, tensor, bc): + def _apply_bc(self, tensor, bc, u=None): # TODO Maybe this could be a singledispatchmethod? if isinstance(bc, DirichletBC): - self._apply_dirichlet_bc(tensor, bc) + if self._diagonal: + bc.set(tensor, self._weight) + elif self._zero_bc_nodes: + bc.zero(tensor) + elif u is not None: + assert self._weight == 1.0 + tensor = tensor.riesz_representation("l2") + bc.apply(tensor, u=u) + else: + # NOTE this only works if tensor is a Function and not a Cofunction + bc.apply(tensor) elif isinstance(bc, EquationBCSplit): bc.zero(tensor) - type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False, - zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal, weight=self._weight).assemble(tensor=tensor) + OneFormAssembler(bc.f, bcs=bc.bcs, + form_compiler_parameters=self._form_compiler_params, + needs_zeroing=False, + zero_bc_nodes=self._zero_bc_nodes, + diagonal=self._diagonal, + weight=self._weight).assemble(tensor=tensor, current_state=u) else: raise AssertionError - def _apply_dirichlet_bc(self, tensor, bc): - if self._diagonal: - bc.set(tensor, self._weight) - elif not self._zero_bc_nodes: - # NOTE this only works if tensor is a Function and not a Cofunction - bc.apply(tensor) - else: - bc.zero(tensor) - def _check_tensor(self, tensor): if tensor.function_space() != self._form.arguments()[0].function_space().dual(): raise ValueError("Form's argument does not match provided result tensor") @@ -1430,7 +1435,8 @@ def _all_assemblers(self): all_assemblers.extend(_assembler._all_assemblers) return tuple(all_assemblers) - def _apply_bc(self, tensor, bc): + def _apply_bc(self, tensor, bc, u=None): + assert u is None op2tensor = tensor.M spaces = tuple(a.function_space() for a in tensor.a.arguments()) V = bc.function_space() @@ -1534,7 +1540,7 @@ def allocate(self): options_prefix=self._options_prefix, appctx=self._appctx or {}) - def assemble(self, tensor=None): + def assemble(self, tensor=None, current_state=None): if tensor is None: tensor = self.allocate() else: diff --git a/firedrake/solving.py b/firedrake/solving.py index 8b82d64f01..ef80d2f4f1 100644 --- a/firedrake/solving.py +++ b/firedrake/solving.py @@ -133,6 +133,9 @@ def solve(*args, **kwargs): To exclude Dirichlet boundary condition nodes through the use of a :class`.RestrictedFunctionSpace`, set the ``restrict`` keyword argument to be True. + + To linearise around the initial guess before imposing boundary + conditions, set the ``pre_apply_bcs`` keyword argument to be False. """ assert len(args) > 0 @@ -151,7 +154,7 @@ def _solve_varproblem(*args, **kwargs): eq, u, bcs, J, Jp, M, form_compiler_parameters, \ solver_parameters, nullspace, nullspace_T, \ near_nullspace, \ - options_prefix, restrict = _extract_args(*args, **kwargs) + options_prefix, restrict, pre_apply_bcs = _extract_args(*args, **kwargs) # Check whether solution is valid if not isinstance(u, (function.Function, vector.Vector)): @@ -184,7 +187,7 @@ def _solve_varproblem(*args, **kwargs): # Create problem problem = vs.NonlinearVariationalProblem(eq.lhs, u, bcs, J, Jp, form_compiler_parameters=form_compiler_parameters, - restrict=restrict) + restrict=restrict, pre_apply_bcs=pre_apply_bcs) # Create solver and call solve solver = vs.NonlinearVariationalSolver(problem, solver_parameters=solver_parameters, nullspace=nullspace, @@ -297,7 +300,7 @@ def _extract_args(*args, **kwargs): valid_kwargs = ["bcs", "J", "Jp", "M", "form_compiler_parameters", "solver_parameters", "nullspace", "transpose_nullspace", "near_nullspace", - "options_prefix", "appctx", "restrict"] + "options_prefix", "appctx", "restrict", "pre_apply_bcs"] for kwarg in kwargs.keys(): if kwarg not in valid_kwargs: raise RuntimeError("Illegal keyword argument '%s'; valid keywords \ @@ -341,10 +344,11 @@ def _extract_args(*args, **kwargs): solver_parameters = kwargs.get("solver_parameters", {}) options_prefix = kwargs.get("options_prefix", None) restrict = kwargs.get("restrict", False) + pre_apply_bcs = kwargs.get("pre_apply_bcs", True) return eq, u, bcs, J, Jp, M, form_compiler_parameters, \ solver_parameters, nullspace, nullspace_T, near_nullspace, \ - options_prefix, restrict + options_prefix, restrict, pre_apply_bcs def _extract_bcs(bcs): diff --git a/firedrake/solving_utils.py b/firedrake/solving_utils.py index 8f03fbaf41..a5544a7867 100644 --- a/firedrake/solving_utils.py +++ b/firedrake/solving_utils.py @@ -221,6 +221,7 @@ def __init__(self, problem, mat_type, pmat_type, appctx=None, self._assemble_residual = get_assembler(self.F, bcs=self.bcs_F, form_compiler_parameters=self.fcp, + zero_bc_nodes=problem.pre_apply_bcs, ).assemble self._jacobian_assembled = False @@ -395,7 +396,7 @@ def form_function(snes, X, F): if ctx._pre_function_callback is not None: ctx._pre_function_callback(X) - ctx._assemble_residual(tensor=ctx._F) + ctx._assemble_residual(tensor=ctx._F, current_state=ctx._x) if ctx._post_function_callback is not None: with ctx._F.dat.vec as F_: diff --git a/firedrake/variational_solver.py b/firedrake/variational_solver.py index 3c8fc8b930..fcd3c78bfe 100644 --- a/firedrake/variational_solver.py +++ b/firedrake/variational_solver.py @@ -52,7 +52,7 @@ class NonlinearVariationalProblem(NonlinearVariationalProblemMixin): def __init__(self, F, u, bcs=None, J=None, Jp=None, form_compiler_parameters=None, - is_linear=False, restrict=False): + is_linear=False, restrict=False, pre_apply_bcs=True): r""" :param F: the nonlinear form :param u: the :class:`.Function` to solve for @@ -68,6 +68,8 @@ def __init__(self, F, u, bcs=None, J=None, :param restrict: (optional) If `True`, use restricted function spaces, that exclude Dirichlet boundary condition nodes, internally for the test and trial spaces. + :param pre_apply_bcs: (optional) If `False`, the problem is linearised + around the initial guess before imposing the boundary conditions. """ V = u.function_space() self.output_space = V @@ -86,6 +88,7 @@ def __init__(self, F, u, bcs=None, J=None, if isinstance(bc, EquationBC): restrict = False self.restrict = restrict + self.pre_apply_bcs = pre_apply_bcs if restrict and bcs: V_res = restricted_function_space(V, extract_subdomain_ids(bcs)) @@ -304,8 +307,9 @@ def solve(self, bounds=None): problem_dms = [V.dm for V in utils.unique(chain.from_iterable(c.function_space() for c in coefficients)) if V.dm != solution_dm] problem_dms.append(solution_dm) - for dbc in problem.dirichlet_bcs(): - dbc.apply(problem.u_restrict) + if problem.pre_apply_bcs: + for dbc in problem.dirichlet_bcs(): + dbc.apply(problem.u_restrict) if bounds is not None: lower, upper = bounds @@ -340,7 +344,7 @@ class LinearVariationalProblem(NonlinearVariationalProblem): @PETSc.Log.EventDecorator() def __init__(self, a, L, u, bcs=None, aP=None, form_compiler_parameters=None, - constant_jacobian=False, restrict=False): + constant_jacobian=False, restrict=False, pre_apply_bcs=True): r""" :param a: the bilinear form :param L: the linear form @@ -358,6 +362,8 @@ def __init__(self, a, L, u, bcs=None, aP=None, :param restrict: (optional) If `True`, use restricted function spaces, that exclude Dirichlet boundary condition nodes, internally for the test and trial spaces. + :param pre_apply_bcs: (optional) If `False`, the problem is linearised + around the initial guess before imposing the boundary conditions. """ # In the linear case, the Jacobian is the equation LHS. J = a @@ -373,7 +379,7 @@ def __init__(self, a, L, u, bcs=None, aP=None, super(LinearVariationalProblem, self).__init__(F, u, bcs, J, aP, form_compiler_parameters=form_compiler_parameters, - is_linear=True, restrict=restrict) + is_linear=True, restrict=restrict, pre_apply_bcs=pre_apply_bcs) self._constant_jacobian = constant_jacobian