Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Constant and derivative #3261

Merged
merged 11 commits into from
Mar 6, 2024
42 changes: 42 additions & 0 deletions firedrake/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@
from pyop2.exceptions import DataTypeError, DataValueError
from firedrake.petsc import PETSc
from firedrake.utils import ScalarType
from ufl.classes import all_ufl_classes, ufl_classes, terminal_classes
from ufl.core.ufl_type import UFLType
from ufl.corealg.multifunction import MultiFunction
from ufl.formatting.ufl2unicode import (
Expression2UnicodeHandler, UC, subscript_number, PrecedenceRules,
colorama,
)
from ufl.utils.counted import Counted


Expand Down Expand Up @@ -54,6 +61,8 @@ class Constant(ufl.constantvalue.ConstantValue, ConstantMixin, TSFCConstantMixin
:class:`~ufl.form.Form` on its own you need to pass a
:func:`~.Mesh` as the domain argument.
"""
_ufl_typecode_ = UFLType._ufl_num_typecodes_
_ufl_handler_name_ = "firedrake_constant"

def __new__(cls, value, domain=None, name=None, count=None):
if domain:
Expand Down Expand Up @@ -197,3 +206,36 @@ def __idiv__(self, o):

def __str__(self):
return str(self.dat.data_ro)


# Unicode handler for Firedrake constants
def _unicode_format_firedrake_constant(self, o):
"""Format a Firedrake constant."""
i = o.count()
var = "C"
if len(o.ufl_shape) == 1:
var += UC.combining_right_arrow_above
elif len(o.ufl_shape) > 1 and self.colorama_bold:
var = f"{colorama.Style.BRIGHT}{var}{colorama.Style.RESET_ALL}"
return f"{var}{subscript_number(i)}"


# This monkey patches ufl2unicode support for Firedrake constants
Expression2UnicodeHandler.firedrake_constant = _unicode_format_firedrake_constant

# This is internally done in UFL by the ufl_type decorator, but we cannot
# do the same here, because we want to use the class name Constant
UFLType._ufl_num_typecodes_ += 1
UFLType._ufl_all_classes_.append(Constant)
UFLType._ufl_all_handler_names_.add('firedrake_constant')
UFLType._ufl_obj_init_counts_.append(0)
UFLType._ufl_obj_del_counts_.append(0)

# And doing the above does not append to these magic UFL variables...
all_ufl_classes.add(Constant)
ufl_classes.add(Constant)
terminal_classes.add(Constant)

# These caches need rebuilding for the new type to be registered
MultiFunction._handlers_cache = {}
ufl.formatting.ufl2unicode._precrules = PrecedenceRules()
12 changes: 8 additions & 4 deletions firedrake/ufl_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,11 +261,10 @@ def derivative(form, u, du=None, coefficient_derivatives=None):
# Replace instances of the constant with a new argument ``x``
# and differentiate wrt ``x``.
V = firedrake.FunctionSpace(mesh, "Real", 0)
x = ufl.Coefficient(V, n + 1)
n += 1
x = ufl.Coefficient(V)
# TODO: Update this line when https://github.com/FEniCS/ufl/issues/171 is fixed
form = ufl.replace(form, {u: x})
u = x
u_orig, u = u, x
else:
raise RuntimeError("Can't compute derivative for form")

Expand All @@ -283,7 +282,12 @@ def derivative(form, u, du=None, coefficient_derivatives=None):
raise ValueError("Shapes of u and du do not match.\n"
"If you passed an indexed part of split(u) into "
"derivative, you need to provide an appropriate du as well.")
return ufl.derivative(form, u, du, internal_coefficient_derivatives)
dform = ufl.derivative(form, u, du, internal_coefficient_derivatives)
if isinstance(uc, firedrake.Constant):
# If we replaced constants with ``x`` to differentiate,
# replace them back to the original symbolic constant
dform = ufl.replace(dform, {u: u_orig})
return dform


@PETSc.Log.EventDecorator()
Expand Down
41 changes: 41 additions & 0 deletions tests/regression/test_constant.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from firedrake import *
from ufl.formatting.ufl2unicode import ufl2unicode
from ufl.classes import IntValue
import numpy as np
import pytest

Expand Down Expand Up @@ -239,3 +241,42 @@ class CustomConstant(Constant):

assert const2.count() == const1.count() + 1
assert const3.count() == const1.count() + 2


def test_derivative_wrt_constant():
mesh = UnitIntervalMesh(5)
V = FunctionSpace(mesh, "CG", 1)

u = TrialFunction(V)
v = TestFunction(V)
c = Constant(5)
f = Function(V).assign(7)
solution_a = Function(V)
solution_b = Function(V)

a = (c**2)*inner(u, v) * dx
L = inner(f, v) * dx
solve(a == L, solution_a)

d = derivative(a, c, IntValue(1))
solve(d == L, solution_b)

assert np.allclose(solution_b.dat.data, (c.dat.data/2)*solution_a.dat.data)
JDBetteridge marked this conversation as resolved.
Show resolved Hide resolved


def test_constant_ufl2unicode():
mesh = UnitIntervalMesh(1)
a = Constant(1.0, name="a")
b = Constant(2.0, name="b")
F = a * a * b * b * dx(mesh)
_ = ufl2unicode(F)

dFda = derivative(F, u=a)
dFdb = derivative(F, u=b)
_ = ufl2unicode(dFda)
_ = ufl2unicode(dFdb)

dFda_du = derivative(F, u=a, du=ufl.classes.IntValue(1))
dFdb_du = derivative(F, u=b, du=ufl.classes.IntValue(1))
_ = ufl2unicode(dFda_du)
_ = ufl2unicode(dFdb_du)
Loading