Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace multifunction with class methods for apply_restrictions #274

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions ufl/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from ufl.differentiation import CoefficientDerivative
from ufl.form import BaseForm, Form, FormSum, ZeroBaseForm
from ufl.matrix import Matrix
from ufl.typing import Self

# --- The Action class represents the action of a numerical object that needs
# to be computed at assembly time ---
Expand Down Expand Up @@ -150,6 +151,13 @@ def __hash__(self):
self._hash = hash(("Action", hash(self._right), hash(self._left)))
return self._hash

def apply_restrictions(self, mapped_operands, side) -> Self:
"""Apply restrictions.

Propagates restrictions in a form towards the terminals.
"""
return Action(mapped_operands[0], mapped_operands[1])


def _check_function_spaces(left, right):
"""Check if the function spaces of left and right match."""
Expand Down
8 changes: 8 additions & 0 deletions ufl/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ufl.argument import Coargument
from ufl.core.ufl_type import ufl_type
from ufl.form import BaseForm, FormSum, ZeroBaseForm
from ufl.typing import Self

# --- The Adjoint class represents the adjoint of a numerical object that
# needs to be computed at assembly time ---
Expand Down Expand Up @@ -119,3 +120,10 @@ def __hash__(self):
if self._hash is None:
self._hash = hash(("Adjoint", hash(self._form)))
return self._hash

def apply_restrictions(self, mapped_operands, side) -> Self:
"""Apply restrictions.

Propagates restrictions in a form towards the terminals.
"""
return Adjoint(mapped_operands[0])
4 changes: 2 additions & 2 deletions ufl/algorithms/apply_algebra_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#
# Modified by Anders Logg, 2009-2010

from ufl.algorithms.map_integrands import map_integrand_dags
from ufl.algorithms.map_integrands import map_integrand_dags_legacy
from ufl.classes import Conj, Grad, Product
from ufl.compound_expressions import cofactor_expr, determinant_expr, deviatoric_expr, inverse_expr
from ufl.core.multiindex import Index, indices
Expand Down Expand Up @@ -147,4 +147,4 @@ def c(i, j):

def apply_algebra_lowering(expr):
"""Expands high level compound operators to equivalent representations using basic operators."""
return map_integrand_dags(LowerCompoundAlgebra(), expr)
return map_integrand_dags_legacy(LowerCompoundAlgebra(), expr)
32 changes: 18 additions & 14 deletions ufl/algorithms/apply_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from ufl.action import Action
from ufl.algorithms.analysis import extract_arguments
from ufl.algorithms.map_integrands import map_integrand_dags
from ufl.algorithms.map_integrands import map_integrand_dags_legacy
from ufl.algorithms.replace_derivative_nodes import replace_derivative_nodes
from ufl.argument import BaseArgument
from ufl.checks import is_cellwise_constant
Expand Down Expand Up @@ -46,7 +46,7 @@
from ufl.core.expr import ufl_err_str
from ufl.core.multiindex import FixedIndex, MultiIndex, indices
from ufl.core.terminal import Terminal
from ufl.corealg.map_dag import map_expr_dag
from ufl.corealg.map_dag import map_expr_dag_legacy
from ufl.corealg.multifunction import MultiFunction
from ufl.differentiation import (
BaseFormCoordinateDerivative,
Expand Down Expand Up @@ -1211,7 +1211,7 @@ def compute_gprimeterm(ngrads, vval, vcomp, wshape, wcomp):
def coordinate_derivative(self, o):
"""Differentiate a coordinate_derivative."""
o = o.ufl_operands
return CoordinateDerivative(map_expr_dag(self, o[0]), o[1], o[2], o[3])
return CoordinateDerivative(map_expr_dag_legacy(self, o[0]), o[1], o[2], o[3])

def base_form_operator(self, o, *dfs):
"""Differentiate a base_form_operator.
Expand Down Expand Up @@ -1364,20 +1364,20 @@ def grad(self, o, f):
"""Apply to a grad."""
rules = GradRuleset(o.ufl_shape[-1])
key = (GradRuleset, o.ufl_shape[-1])
return map_expr_dag(rules, f, vcache=self.vcaches[key], rcache=self.rcaches[key])
return map_expr_dag_legacy(rules, f, vcache=self.vcaches[key], rcache=self.rcaches[key])

def reference_grad(self, o, f):
"""Apply to a reference_grad."""
rules = ReferenceGradRuleset(o.ufl_shape[-1]) # FIXME: Look over this and test better.
key = (ReferenceGradRuleset, o.ufl_shape[-1])
return map_expr_dag(rules, f, vcache=self.vcaches[key], rcache=self.rcaches[key])
return map_expr_dag_legacy(rules, f, vcache=self.vcaches[key], rcache=self.rcaches[key])

def variable_derivative(self, o, f, dummy_v):
"""Apply to a variable_derivative."""
op = o.ufl_operands[1]
rules = VariableRuleset(op)
key = (VariableRuleset, op)
return map_expr_dag(rules, f, vcache=self.vcaches[key], rcache=self.rcaches[key])
return map_expr_dag_legacy(rules, f, vcache=self.vcaches[key], rcache=self.rcaches[key])

def coefficient_derivative(self, o, f, dummy_w, dummy_v, dummy_cd):
"""Apply to a coefficient_derivative."""
Expand All @@ -1389,7 +1389,9 @@ def coefficient_derivative(self, o, f, dummy_w, dummy_v, dummy_cd):
key = (GateauxDerivativeRuleset, w, v, cd)
# We need to go through the dag first to record the pending
# operations
mapped_expr = map_expr_dag(rules, f, vcache=self.vcaches[key], rcache=self.rcaches[key])
mapped_expr = map_expr_dag_legacy(
rules, f, vcache=self.vcaches[key], rcache=self.rcaches[key]
)
# Need to account for pending operations that have been stored
# in other integrands
self.pending_operations += pending_operations
Expand All @@ -1413,7 +1415,9 @@ def base_form_operator_derivative(self, o, f, dummy_w, dummy_v, dummy_cd):
arguments += (arg,)
return ZeroBaseForm(arguments)
# We need to go through the dag first to record the pending operations
mapped_expr = map_expr_dag(rules, f, vcache=self.vcaches[key], rcache=self.rcaches[key])
mapped_expr = map_expr_dag_legacy(
rules, f, vcache=self.vcaches[key], rcache=self.rcaches[key]
)

mapped_f = rules.coefficient(f)
if mapped_f != 0:
Expand All @@ -1429,7 +1433,7 @@ def coordinate_derivative(self, o, f, dummy_w, dummy_v, dummy_cd):
o_ = o.ufl_operands
key = (CoordinateDerivative, o_[0])
return CoordinateDerivative(
map_expr_dag(self, o_[0], vcache=self.vcaches[key], rcache=self.rcaches[key]),
map_expr_dag_legacy(self, o_[0], vcache=self.vcaches[key], rcache=self.rcaches[key]),
o_[1],
o_[2],
o_[3],
Expand All @@ -1440,7 +1444,7 @@ def base_form_coordinate_derivative(self, o, f, dummy_w, dummy_v, dummy_cd):
o_ = o.ufl_operands
key = (BaseFormCoordinateDerivative, o_[0])
return BaseFormCoordinateDerivative(
map_expr_dag(self, o_[0], vcache=self.vcaches[key], rcache=self.rcaches[key]),
map_expr_dag_legacy(self, o_[0], vcache=self.vcaches[key], rcache=self.rcaches[key]),
o_[1],
o_[2],
o_[3],
Expand Down Expand Up @@ -1561,7 +1565,7 @@ def apply_derivatives(expression):
# - 0.
# Example:
# → If derivative(F(u, N(u); v), u) was taken the following line would compute `∂F/∂u`.
dexpression_dvar = map_integrand_dags(rules, expression)
dexpression_dvar = map_integrand_dags_legacy(rules, expression)

# Get the recorded delayed operations
pending_operations = rules.pending_operations
Expand All @@ -1583,7 +1587,7 @@ def apply_derivatives(expression):
# -- Replace dexpr/dvar by dexpr/dN -- #
# We don't use `apply_derivatives` since the differentiation is
# done via `\partial` and not `d`.
dexpr_dN = map_integrand_dags(
dexpr_dN = map_integrand_dags_legacy(
rules, replace_derivative_nodes(expression, {var.ufl_operands[0]: N})
)
# -- Add the BaseFormOperatorDerivative node -- #
Expand Down Expand Up @@ -1767,10 +1771,10 @@ def coordinate_derivative(self, o, f, w, v, cd):
_, w, v, cd = o.ufl_operands
rules = CoordinateDerivativeRuleset(w, v, cd)
key = (CoordinateDerivativeRuleset, w, v, cd)
return map_expr_dag(rules, f, vcache=self.vcache[key], rcache=self.rcache[key])
return map_expr_dag_legacy(rules, f, vcache=self.vcache[key], rcache=self.rcache[key])


def apply_coordinate_derivatives(expression):
"""Apply coordinate derivatives to an expression."""
rules = CoordinateDerivativeRuleDispatcher()
return map_integrand_dags(rules, expression)
return map_integrand_dags_legacy(rules, expression)
4 changes: 2 additions & 2 deletions ufl/algorithms/apply_function_pullbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#
# SPDX-License-Identifier: LGPL-3.0-or-later

from ufl.algorithms.map_integrands import map_integrand_dags
from ufl.algorithms.map_integrands import map_integrand_dags_legacy
from ufl.classes import ReferenceValue
from ufl.corealg.multifunction import MultiFunction, memoized_handler

Expand Down Expand Up @@ -58,4 +58,4 @@ def apply_function_pullbacks(expr):
Args:
expr: An Expression
"""
return map_integrand_dags(FunctionPullbackApplier(), expr)
return map_integrand_dags_legacy(FunctionPullbackApplier(), expr)
6 changes: 3 additions & 3 deletions ufl/algorithms/apply_geometry_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
)
from ufl.compound_expressions import cross_expr, determinant_expr, inverse_expr
from ufl.core.multiindex import Index, indices
from ufl.corealg.map_dag import map_expr_dag
from ufl.corealg.map_dag import map_expr_dag_legacy
from ufl.corealg.multifunction import MultiFunction, memoized_handler
from ufl.domain import extract_unique_domain
from ufl.measure import custom_integral_types, point_integral_types
Expand Down Expand Up @@ -485,13 +485,13 @@ def apply_geometry_lowering(form, preserve_types=()):
preserve_types = set(preserve_types) | set(automatic_preserve_types)

mf = GeometryLoweringApplier(preserve_types)
newintegrand = map_expr_dag(mf, integral.integrand())
newintegrand = map_expr_dag_legacy(mf, integral.integrand())
return integral.reconstruct(integrand=newintegrand)

elif isinstance(form, Expr):
expr = form
mf = GeometryLoweringApplier(preserve_types)
return map_expr_dag(mf, expr)
return map_expr_dag_legacy(mf, expr)

else:
raise ValueError(f"Invalid type {form.__class__.__name__}")
Loading
Loading