Skip to content

Commit

Permalink
FEniCS-style bcs
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 28, 2025
1 parent 58f3d6f commit 7226e77
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 36 deletions.
51 changes: 29 additions & 22 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ def assemble(expr, *args, **kwargs):
if args:
raise RuntimeError(f"Got unexpected args: {args}")
tensor = kwargs.pop("tensor", None)
return get_assembler(expr, *args, **kwargs).assemble(tensor=tensor)
u = kwargs.pop("current_state", None)
return get_assembler(expr, *args, **kwargs).assemble(tensor=tensor, current_state=u)


def get_assembler(form, *args, **kwargs):
Expand Down Expand Up @@ -358,7 +359,7 @@ def allocation_integral_types(self):
else:
return self._allocation_integral_types

def assemble(self, tensor=None):
def assemble(self, tensor=None, current_state=None):
"""Assemble the form.
Parameters
Expand Down Expand Up @@ -389,7 +390,7 @@ def visitor(e, *operands):
rank = len(self._form.arguments())
if rank == 1 and not isinstance(result, ufl.ZeroBaseForm):
for bc in self._bcs:
bc.zero(result)
OneFormAssembler._apply_bc(self, result, bc, u=current_state)

if tensor:
BaseFormAssembler.update_tensor(result, tensor)
Expand Down Expand Up @@ -968,7 +969,7 @@ def __init__(self, form, bcs=None, form_compiler_parameters=None, needs_zeroing=
super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters)
self._needs_zeroing = needs_zeroing

def assemble(self, tensor=None):
def assemble(self, tensor=None, current_state=None):
"""Assemble the form.
Parameters
Expand Down Expand Up @@ -998,12 +999,12 @@ def assemble(self, tensor=None):
self.execute_parloops(tensor)

for bc in self._bcs:
self._apply_bc(tensor, bc)
self._apply_bc(tensor, bc, u=current_state)

return self.result(tensor)

@abc.abstractmethod
def _apply_bc(self, tensor, bc):
def _apply_bc(self, tensor, bc, u=None):
"""Apply boundary condition."""

@abc.abstractmethod
Expand Down Expand Up @@ -1138,7 +1139,7 @@ def allocate(self):
comm=self._form.ufl_domains()[0]._comm
)

def _apply_bc(self, tensor, bc):
def _apply_bc(self, tensor, bc, u=None):
pass

def _check_tensor(self, tensor):
Expand Down Expand Up @@ -1199,26 +1200,31 @@ def allocate(self):
else:
raise RuntimeError(f"Not expected: found rank = {rank} and diagonal = {self._diagonal}")

def _apply_bc(self, tensor, bc):
def _apply_bc(self, tensor, bc, u=None):
# TODO Maybe this could be a singledispatchmethod?
if isinstance(bc, DirichletBC):
self._apply_dirichlet_bc(tensor, bc)
if self._diagonal:
bc.set(tensor, self._weight)
elif self._zero_bc_nodes:
bc.zero(tensor)
elif u is not None:
assert self._weight == 1.0
tensor = tensor.riesz_representation("l2")
bc.apply(tensor, u=u)
else:
# NOTE this only works if tensor is a Function and not a Cofunction
bc.apply(tensor)
elif isinstance(bc, EquationBCSplit):
bc.zero(tensor)
type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False,
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal, weight=self._weight).assemble(tensor=tensor)
OneFormAssembler(bc.f, bcs=bc.bcs,
form_compiler_parameters=self._form_compiler_params,
needs_zeroing=False,
zero_bc_nodes=self._zero_bc_nodes,
diagonal=self._diagonal,
weight=self._weight).assemble(tensor=tensor, current_state=u)
else:
raise AssertionError

def _apply_dirichlet_bc(self, tensor, bc):
if self._diagonal:
bc.set(tensor, self._weight)
elif not self._zero_bc_nodes:
# NOTE this only works if tensor is a Function and not a Cofunction
bc.apply(tensor)
else:
bc.zero(tensor)

def _check_tensor(self, tensor):
if tensor.function_space() != self._form.arguments()[0].function_space().dual():
raise ValueError("Form's argument does not match provided result tensor")
Expand Down Expand Up @@ -1430,7 +1436,8 @@ def _all_assemblers(self):
all_assemblers.extend(_assembler._all_assemblers)
return tuple(all_assemblers)

def _apply_bc(self, tensor, bc):
def _apply_bc(self, tensor, bc, u=None):
assert u is None
op2tensor = tensor.M
spaces = tuple(a.function_space() for a in tensor.a.arguments())
V = bc.function_space()
Expand Down Expand Up @@ -1534,7 +1541,7 @@ def allocate(self):
options_prefix=self._options_prefix,
appctx=self._appctx or {})

def assemble(self, tensor=None):
def assemble(self, tensor=None, current_state=None):
if tensor is None:
tensor = self.allocate()
else:
Expand Down
12 changes: 8 additions & 4 deletions firedrake/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ def solve(*args, **kwargs):
To exclude Dirichlet boundary condition nodes through the use of a
:class`.RestrictedFunctionSpace`, set the ``restrict`` keyword
argument to be True.
To linearise around the initial guess before imposing boundary
conditions, set the ``pre_apply_bcs`` keyword argument to be False.
"""

assert len(args) > 0
Expand All @@ -151,7 +154,7 @@ def _solve_varproblem(*args, **kwargs):
eq, u, bcs, J, Jp, M, form_compiler_parameters, \
solver_parameters, nullspace, nullspace_T, \
near_nullspace, \
options_prefix, restrict = _extract_args(*args, **kwargs)
options_prefix, restrict, pre_apply_bcs = _extract_args(*args, **kwargs)

# Check whether solution is valid
if not isinstance(u, (function.Function, vector.Vector)):
Expand Down Expand Up @@ -184,7 +187,7 @@ def _solve_varproblem(*args, **kwargs):
# Create problem
problem = vs.NonlinearVariationalProblem(eq.lhs, u, bcs, J, Jp,
form_compiler_parameters=form_compiler_parameters,
restrict=restrict)
restrict=restrict, pre_apply_bcs=pre_apply_bcs)
# Create solver and call solve
solver = vs.NonlinearVariationalSolver(problem, solver_parameters=solver_parameters,
nullspace=nullspace,
Expand Down Expand Up @@ -297,7 +300,7 @@ def _extract_args(*args, **kwargs):
valid_kwargs = ["bcs", "J", "Jp", "M",
"form_compiler_parameters", "solver_parameters",
"nullspace", "transpose_nullspace", "near_nullspace",
"options_prefix", "appctx", "restrict"]
"options_prefix", "appctx", "restrict", "pre_apply_bcs"]
for kwarg in kwargs.keys():
if kwarg not in valid_kwargs:
raise RuntimeError("Illegal keyword argument '%s'; valid keywords \
Expand Down Expand Up @@ -341,10 +344,11 @@ def _extract_args(*args, **kwargs):
solver_parameters = kwargs.get("solver_parameters", {})
options_prefix = kwargs.get("options_prefix", None)
restrict = kwargs.get("restrict", False)
pre_apply_bcs = kwargs.get("pre_apply_bcs", True)

return eq, u, bcs, J, Jp, M, form_compiler_parameters, \
solver_parameters, nullspace, nullspace_T, near_nullspace, \
options_prefix, restrict
options_prefix, restrict, pre_apply_bcs


def _extract_bcs(bcs):
Expand Down
14 changes: 13 additions & 1 deletion firedrake/solving_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,15 @@ def __init__(self, problem, mat_type, pmat_type, appctx=None,
self.bcs_J = tuple(bc.extract_form('J') for bc in problem.bcs)
self.bcs_Jp = tuple(bc.extract_form('Jp') for bc in problem.bcs)

self._g = None
if not problem.pre_apply_bcs:
# Delayed lifting of DirichletBCs
self._g = self._x.copy(deepcopy=True)
self.F -= self.J * self._g

self._assemble_residual = get_assembler(self.F, bcs=self.bcs_F,
form_compiler_parameters=self.fcp,
zero_bc_nodes=problem.pre_apply_bcs,
).assemble

self._jacobian_assembled = False
Expand Down Expand Up @@ -395,7 +402,12 @@ def form_function(snes, X, F):
if ctx._pre_function_callback is not None:
ctx._pre_function_callback(X)

ctx._assemble_residual(tensor=ctx._F)
if not ctx._problem.pre_apply_bcs:
# Delayed lifting of DirichletBC
ctx._g.zero()
for bc in ctx.bcs_F:
bc.apply(ctx._g, u=ctx._x)
ctx._assemble_residual(tensor=ctx._F, current_state=ctx._x)

if ctx._post_function_callback is not None:
with ctx._F.dat.vec as F_:
Expand Down
17 changes: 12 additions & 5 deletions firedrake/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class NonlinearVariationalProblem(NonlinearVariationalProblemMixin):
def __init__(self, F, u, bcs=None, J=None,
Jp=None,
form_compiler_parameters=None,
is_linear=False, restrict=False):
is_linear=False, restrict=False, pre_apply_bcs=True):
r"""
:param F: the nonlinear form
:param u: the :class:`.Function` to solve for
Expand All @@ -68,6 +68,8 @@ def __init__(self, F, u, bcs=None, J=None,
:param restrict: (optional) If `True`, use restricted function spaces,
that exclude Dirichlet boundary condition nodes, internally for
the test and trial spaces.
:param pre_apply_bcs: (optional) If `False`, the problem is linearised
around the initial guess before imposing the boundary conditions.
"""
V = u.function_space()
self.output_space = V
Expand All @@ -86,6 +88,7 @@ def __init__(self, F, u, bcs=None, J=None,
if isinstance(bc, EquationBC):
restrict = False
self.restrict = restrict
self.pre_apply_bcs = pre_apply_bcs

if restrict and bcs:
V_res = restricted_function_space(V, extract_subdomain_ids(bcs))
Expand Down Expand Up @@ -304,8 +307,9 @@ def solve(self, bounds=None):
problem_dms = [V.dm for V in utils.unique(chain.from_iterable(c.function_space() for c in coefficients)) if V.dm != solution_dm]
problem_dms.append(solution_dm)

for dbc in problem.dirichlet_bcs():
dbc.apply(problem.u_restrict)
if problem.pre_apply_bcs:
for dbc in problem.dirichlet_bcs():
dbc.apply(problem.u_restrict)

if bounds is not None:
lower, upper = bounds
Expand All @@ -327,6 +331,7 @@ def solve(self, bounds=None):
self._setup = True
if problem.restrict:
problem.u.interpolate(problem.u_restrict)

solving_utils.check_snes_convergence(self.snes)

# Grab the comm associated with the `_problem` and call PETSc's garbage cleanup routine
Expand All @@ -340,7 +345,7 @@ class LinearVariationalProblem(NonlinearVariationalProblem):
@PETSc.Log.EventDecorator()
def __init__(self, a, L, u, bcs=None, aP=None,
form_compiler_parameters=None,
constant_jacobian=False, restrict=False):
constant_jacobian=False, restrict=False, pre_apply_bcs=True):
r"""
:param a: the bilinear form
:param L: the linear form
Expand All @@ -358,6 +363,8 @@ def __init__(self, a, L, u, bcs=None, aP=None,
:param restrict: (optional) If `True`, use restricted function spaces,
that exclude Dirichlet boundary condition nodes, internally for
the test and trial spaces.
:param pre_apply_bcs: (optional) If `False`, the problem is linearised
around the initial guess before imposing the boundary conditions.
"""
# In the linear case, the Jacobian is the equation LHS.
J = a
Expand All @@ -373,7 +380,7 @@ def __init__(self, a, L, u, bcs=None, aP=None,

super(LinearVariationalProblem, self).__init__(F, u, bcs, J, aP,
form_compiler_parameters=form_compiler_parameters,
is_linear=True, restrict=restrict)
is_linear=True, restrict=restrict, pre_apply_bcs=pre_apply_bcs)
self._constant_jacobian = constant_jacobian


Expand Down
39 changes: 35 additions & 4 deletions tests/firedrake/regression/test_solving_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,12 @@ def test_constant_jacobian_lvs():
assert not (norm(assemble(out*5 - f)) < 2e-7)


def test_solve_cofunction_rhs():
mesh = UnitIntervalMesh(10)
@pytest.fixture
def mesh(request):
return UnitIntervalMesh(10)


def test_solve_cofunction_rhs(mesh):
V = FunctionSpace(mesh, "CG", 1)
x, = SpatialCoordinate(mesh)

Expand All @@ -242,8 +246,7 @@ def test_solve_cofunction_rhs():
assert np.allclose(L.dat.data, Lold.dat.data)


def test_solve_empty_form_rhs():
mesh = UnitIntervalMesh(10)
def test_solve_empty_form_rhs(mesh):
V = FunctionSpace(mesh, "CG", 1)

u = TrialFunction(V)
Expand All @@ -257,3 +260,31 @@ def test_solve_empty_form_rhs():
w = Function(V)
solve(a == L, w, bcs)
assert errornorm(x, w) < 1E-10


def test_solve_pre_apply_bcs(mesh):
V = VectorFunctionSpace(mesh, "CG", 1)
x = SpatialCoordinate(mesh)

eps = Constant(0.1)
g = -eps*x

bc = DirichletBC(V, g, "on_boundary")

u = Function(V)

lam = Constant(1E0)
dim = mesh.geometric_dimension()
F = grad(u) + Identity(dim)
J = det(F)
logJ = 0.5*ln(J**2)

W = (1/2)*(inner(F, F) - dim - 2*logJ + lam*logJ**2) * dx(degree=4)

F = derivative(W, u)
# Raises nans if pre_apply_bcs=True
solve(F == 0, u, bc, pre_apply_bcs=False,
solver_parameters={"snes_atol": 1E-8,
"snes_rtol": 1E-8,
"snes_linesearch_type": "basic",
"snes_monitor": None})

0 comments on commit 7226e77

Please sign in to comment.