Skip to content

Commit

Permalink
ConstantValue: Support general dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jun 6, 2024
1 parent ea144b0 commit 1905eaa
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
8 changes: 8 additions & 0 deletions test/test_literals.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
__authors__ = "Martin Sandve Alnæs"
__date__ = "2011-04-14"

import numpy as np

from ufl import PermutationSymbol, as_matrix, as_vector, indices, product
from ufl.classes import Indexed
from ufl.constantvalue import ComplexValue, FloatValue, IntValue, Zero, as_ufl
Expand Down Expand Up @@ -29,13 +31,15 @@ def test_float(self):
f4 = FloatValue(1.0)
f5 = 3 - FloatValue(1) - 1
f6 = 3 * FloatValue(2) / 6
f7 = as_ufl(np.ones((1,), dtype="d")[0])

assert f1 == f1
self.assertNotEqual(f1, f2) # IntValue vs FloatValue, == compares representations!
assert f2 == f3
assert f2 == f4
assert f2 == f5
assert f2 == f6
assert f2 == f7


def test_int(self):
Expand All @@ -45,13 +49,15 @@ def test_int(self):
f4 = IntValue(1.0)
f5 = 3 - IntValue(1) - 1
f6 = 3 * IntValue(2) / 6
f7 = as_ufl(np.ones((1,), dtype="int")[0])

assert f1 == f1
self.assertNotEqual(f1, f2) # IntValue vs FloatValue, == compares representations!
assert f1 == f3
assert f1 == f4
assert f1 == f5
assert f2 == f6 # Division produces a FloatValue
assert f1 == f7


def test_complex(self):
Expand All @@ -62,6 +68,7 @@ def test_complex(self):
f5 = ComplexValue(1.0 + 1.0j)
f6 = as_ufl(1.0)
f7 = as_ufl(1.0j)
f8 = as_ufl(np.array([1 + 1j], dtype="complex")[0])

assert f1 == f1
assert f1 == f4
Expand All @@ -71,6 +78,7 @@ def test_complex(self):
assert f5 == f2 + f3
assert f4 == f5
assert f6 + f7 == f2 + f3
assert f4 == f8


def test_scalar_sums(self):
Expand Down
11 changes: 6 additions & 5 deletions ufl/constantvalue.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# Modified by Anders Logg, 2011.
# Modified by Massimiliano Leoni, 2016.

import numbers
from math import atan2

import ufl
Expand Down Expand Up @@ -506,12 +507,12 @@ def as_ufl(expression):
"""Converts expression to an Expr if possible."""
if isinstance(expression, (Expr, ufl.BaseForm)):
return expression
elif isinstance(expression, complex):
return ComplexValue(expression)
elif isinstance(expression, float):
return FloatValue(expression)
elif isinstance(expression, int):
elif isinstance(expression, numbers.Integral):
return IntValue(expression)
elif isinstance(expression, numbers.Real):
return FloatValue(expression)
elif isinstance(expression, numbers.Complex):
return ComplexValue(expression)
else:
raise ValueError(
f"Invalid type conversion: {expression} can not be converted to any UFL type."
Expand Down

0 comments on commit 1905eaa

Please sign in to comment.