Skip to content

Commit

Permalink
refactor arithmetic power
Browse files Browse the repository at this point in the history
  • Loading branch information
mmatera committed Jul 25, 2023
1 parent 2aa00f0 commit 78496ef
Show file tree
Hide file tree
Showing 4 changed files with 197 additions and 40 deletions.
81 changes: 60 additions & 21 deletions mathics/builtin/arithfns/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
"""


import sympy

from mathics.builtin.arithmetic import create_infix
from mathics.builtin.base import (
BinaryOperator,
Expand Down Expand Up @@ -45,7 +48,6 @@
Symbol,
SymbolDivide,
SymbolHoldForm,
SymbolNull,
SymbolPower,
SymbolTimes,
)
Expand All @@ -56,10 +58,17 @@
SymbolInfix,
SymbolLeft,
SymbolMinus,
SymbolOverflow,
SymbolPattern,
SymbolSequence,
)
from mathics.eval.arithmetic import eval_Plus, eval_Times
from mathics.eval.arithmetic import (
associate_powers,
eval_Exponential,
eval_Plus,
eval_Power_inexact,
eval_Power_number,
eval_Times,
)
from mathics.eval.nevaluator import eval_N
from mathics.eval.numerify import numerify

Expand Down Expand Up @@ -535,15 +544,15 @@ class Power(BinaryOperator, MPMathFunction):
# Remember to up sympy doc link when this is corrected
sympy_name = "Pow"

def eval_exp(self, x, evaluation):
"Power[E, x]"
return eval_Exponential(x)

def eval_check(self, x, y, evaluation):
"Power[x_, y_]"

# Power uses MPMathFunction but does some error checking first
if isinstance(x, Number) and x.is_zero:
if isinstance(y, Number):
y_err = y
else:
y_err = eval_N(y, evaluation)
# if x is zero
if x.is_zero:
y_err = y if isinstance(y, Number) else eval_N(y, evaluation)
if isinstance(y_err, Number):
py_y = y_err.round_to_float(permit_complex=True).real
if py_y > 0:
Expand All @@ -557,17 +566,47 @@ def eval_check(self, x, y, evaluation):
evaluation.message(
"Power", "infy", Expression(SymbolPower, x, y_err)
)
return SymbolComplexInfinity
if isinstance(x, Complex) and x.real.is_zero:
yhalf = Expression(SymbolTimes, y, RationalOneHalf)
factor = self.eval(Expression(SymbolSequence, x.imag, y), evaluation)
return Expression(
SymbolTimes, factor, Expression(SymbolPower, IntegerM1, yhalf)
)

result = self.eval(Expression(SymbolSequence, x, y), evaluation)
if result is None or result != SymbolNull:
return result
return SymbolComplexInfinity

# If x and y are inexact numbers, use the numerical function

if x.is_inexact() and y.is_inexact():
try:
return eval_Power_inexact(x, y)
except OverflowError:
evaluation.message("General", "ovfl")
return Expression(SymbolOverflow)

# Tries to associate powers a^b^c-> a^(b*c)
assoc = associate_powers(x, y)
if not assoc.has_form("Power", 2):
return assoc

assoc = numerify(assoc, evaluation)
x, y = assoc.elements
# If x and y are numbers
if isinstance(x, Number) and isinstance(y, Number):
try:
return eval_Power_number(x, y)
except OverflowError:
evaluation.message("General", "ovfl")
return Expression(SymbolOverflow)

# if x or y are inexact, leave the expression
# as it is:
if x.is_inexact() or y.is_inexact():
return assoc

# Finally, try to convert to sympy
base_sp, exp_sp = x.to_sympy(), y.to_sympy()
if base_sp is None or exp_sp is None:
# If base or exp can not be converted to sympy,
# returns the result of applying the associative
# rule.
return assoc

result = from_sympy(sympy.Pow(base_sp, exp_sp))
return result.evaluate_elements(evaluation)


class Sqrt(SympyFunction):
Expand Down
105 changes: 102 additions & 3 deletions mathics/eval/arithmetic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# -*- coding: utf-8 -*-

"""
arithmetic-related evaluation functions.
helper functions for arithmetic evaluation, which do not
depends on the evaluation context. Conversions to Sympy are
used just as a last resource.
Many of these do do depend on the evaluation context. Conversions to Sympy are
used just as a last resource.
Expand Down Expand Up @@ -320,6 +322,28 @@ def eval_complex_sign(n: BaseElement) -> Optional[BaseElement]:
return sign or eval_complex_sign(expr)


def eval_Sign_number(n: Number) -> Number:
"""
Evals the absolute value of a number.
"""
if n.is_zero:
return Integer0
if isinstance(n, (Integer, Rational, Real)):
return Integer1 if n.value > 0 else IntegerM1
if isinstance(n, Complex):
abs_sq = eval_add_numbers(
*(eval_multiply_numbers(x, x) for x in (n.real, n.imag))
)
criteria = eval_add_numbers(abs_sq, IntegerM1)
if test_zero_arithmetic_expr(criteria):
return n
if n.is_inexact():
return eval_multiply_numbers(n, eval_Power_number(abs_sq, RealM0p5))
if test_zero_arithmetic_expr(criteria, numeric=True):
return n
return eval_multiply_numbers(n, eval_Power_number(abs_sq, RationalMOneHalf))


def eval_mpmath_function(
mpmath_function: Callable, *args: Number, prec: Optional[int] = None
) -> Optional[Number]:
Expand Down Expand Up @@ -347,6 +371,31 @@ def eval_mpmath_function(
return call_mpmath(mpmath_function, tuple(mpmath_args), prec)


def eval_Exponential(exp: BaseElement) -> BaseElement:
"""
Eval E^exp
"""
# If both base and exponent are exact quantities,
# use sympy.

if not exp.is_inexact():
exp_sp = exp.to_sympy()
if exp_sp is None:
return None
return from_sympy(sympy.Exp(exp_sp))

prec = exp.get_precision()
if prec is not None:
if exp.is_machine_precision():
number = mpmath.exp(exp.to_mpmath())
result = from_mpmath(number)
return result
else:
with mpmath.workprec(prec):
number = mpmath.exp(exp.to_mpmath())
return from_mpmath(number, prec)


def eval_Plus(*items: BaseElement) -> BaseElement:
"evaluate Plus for general elements"
numbers, items_tuple = segregate_numbers_from_sorted_list(*items)
Expand Down Expand Up @@ -645,8 +694,58 @@ def eval_Times(*items: BaseElement) -> BaseElement:
)


def associate_powers(expr: BaseElement, power: BaseElement = Integer1) -> BaseElement:
"""
base^a^b^c^...^power -> base^(a*b*c*...power)
provided one of the following cases
* `a`, `b`, ... `power` are all integer numbers
* `a`, `b`,... are Rational/Real number with absolute value <=1,
and the other powers are not integer numbers.
* `a` is not a Rational/Real number, and b, c, ... power are all
integer numbers.
"""
powers = []
base = expr
if power is not Integer1:
powers.append(power)

while base.has_form("Power", 2):
previous_base, outer_power = base, power
base, power = base.elements
if len(powers) == 0:
if power is not Integer1:
powers.append(power)
continue
if power is IntegerM1:
powers.append(power)
continue
if isinstance(power, (Rational, Real)):
if abs(power.value) < 1:
powers.append(power)
continue
# power is not rational/real and outer_power is integer,
elif isinstance(outer_power, Integer):
if power is not Integer1:
powers.append(power)
if isinstance(power, Integer):
continue
else:
break
# in any other case, use the previous base and
# exit the loop
base = previous_base
break

if len(powers) == 0:
return base
elif len(powers) == 1:
return Expression(SymbolPower, base, powers[0])
result = Expression(SymbolPower, base, Expression(SymbolTimes, *powers))
return result


def eval_add_numbers(
*numbers: Number,
*numbers: List[Number],
) -> BaseElement:
"""
Add the elements in ``numbers``.
Expand Down Expand Up @@ -693,7 +792,7 @@ def eval_inverse_number(n: Number) -> Number:
return eval_Power_number(n, IntegerM1)


def eval_multiply_numbers(*numbers: Number) -> Number:
def eval_multiply_numbers(*numbers: Number) -> BaseElement:
"""
Multiply the elements in ``numbers``.
"""
Expand Down
8 changes: 4 additions & 4 deletions test/builtin/arithmetic/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def test_directed_infinity_precedence(str_expr, str_expected, msg):
("I^(2/3)", "(-1) ^ (1 / 3)", None),
# In WMA, the next test would return ``-(-I)^(2/3)``
# which is less compact and elegant...
# ("(-I)^(2/3)", "(-1) ^ (-1 / 3)", None),
("(-I)^(2/3)", "(-1) ^ (-1 / 3)", None),
("(2+3I)^3", "-46 + 9 I", None),
("(1.+3. I)^.6", "1.46069 + 1.35921 I", None),
("3^(1+2 I)", "3 ^ (1 + 2 I)", None),
Expand All @@ -208,15 +208,15 @@ def test_directed_infinity_precedence(str_expr, str_expected, msg):
# sympy, which produces the result
("(3/Pi)^(-I)", "(3 / Pi) ^ (-I)", None),
# Association rules
# ('(a^"w")^2', 'a^(2 "w")', "Integer power of a power with string exponent"),
('(a^"w")^2', 'a^(2 "w")', "Integer power of a power with string exponent"),
('(a^2)^"w"', '(a ^ 2) ^ "w"', None),
('(a^2)^"w"', '(a ^ 2) ^ "w"', None),
("(a^2)^(1/2)", "Sqrt[a ^ 2]", None),
("(a^(1/2))^2", "a", None),
("(a^(1/2))^2", "a", None),
("(a^(3/2))^3.", "(a ^ (3 / 2)) ^ 3.", None),
# ("(a^(1/2))^3.", "a ^ 1.5", "Power associativity rational, real"),
# ("(a^(.3))^3.", "a ^ 0.9", "Power associativity for real powers"),
("(a^(1/2))^3.", "a ^ 1.5", "Power associativity rational, real"),
("(a^(.3))^3.", "a ^ 0.9", "Power associativity for real powers"),
("(a^(1.3))^3.", "(a ^ 1.3) ^ 3.", None),
# Exponentials involving expressions
("(a^(p-2 q))^3", "a ^ (3 p - 6 q)", None),
Expand Down
43 changes: 31 additions & 12 deletions test/format/test_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,34 +456,53 @@
"Sqrt[1/(1+1/(1+1/a))]": {
"msg": "SqrtBox",
"text": {
"System`StandardForm": "Sqrt[1 / (1+1 / (1+1 / a))]",
"System`TraditionalForm": "Sqrt[1 / (1+1 / (1+1 / a))]",
"System`InputForm": "Sqrt[1 / (1 + 1 / (1 + 1 / a))]",
"System`OutputForm": "Sqrt[1 / (1 + 1 / (1 + 1 / a))]",
"System`StandardForm": "1 / Sqrt[1+1 / (1+1 / a)]",
"System`TraditionalForm": "1 / Sqrt[1+1 / (1+1 / a)]",
"System`InputForm": "1 / Sqrt[1 + 1 / (1 + 1 / a)]",
"System`OutputForm": "1 / Sqrt[1 + 1 / (1 + 1 / a)]",
},
"mathml": {
"System`StandardForm": (
"<msqrt> <mfrac><mn>1</mn> <mrow><mn>1</mn> <mo>+</mo> <mfrac><mn>1</mn> <mrow><mn>1</mn> <mo>+</mo> <mfrac><mn>1</mn> <mi>a</mi></mfrac></mrow></mfrac></mrow></mfrac> </msqrt>",
(
r"<mfrac><mn>1</mn> <msqrt> <mrow><mn>1</mn> <mo>+</mo> <mfrac><mn>1</mn> "
r"<mrow><mn>1</mn> <mo>+</mo> <mfrac><mn>1</mn> <mi>a</mi></mfrac></mrow></mfrac></mrow> "
r"</msqrt></mfrac>"
),
"Fragile!",
),
"System`TraditionalForm": (
"<msqrt> <mfrac><mn>1</mn> <mrow><mn>1</mn> <mo>+</mo> <mfrac><mn>1</mn> <mrow><mn>1</mn> <mo>+</mo> <mfrac><mn>1</mn> <mi>a</mi></mfrac></mrow></mfrac></mrow></mfrac> </msqrt>",
(
r"<mfrac><mn>1</mn> <msqrt> <mrow><mn>1</mn> <mo>+</mo> <mfrac><mn>1</mn> "
r"<mrow><mn>1</mn> <mo>+</mo> <mfrac><mn>1</mn> <mi>a</mi></mfrac></mrow></mfrac></mrow> "
r"</msqrt></mfrac>"
),
"Fragile!",
),
"System`InputForm": (
"<mrow><mi>Sqrt</mi> <mo>[</mo> <mrow><mtext>1</mtext> <mtext>&nbsp;/&nbsp;</mtext> <mrow><mo>(</mo> <mrow><mtext>1</mtext> <mtext>&nbsp;+&nbsp;</mtext> <mrow><mtext>1</mtext> <mtext>&nbsp;/&nbsp;</mtext> <mrow><mo>(</mo> <mrow><mtext>1</mtext> <mtext>&nbsp;+&nbsp;</mtext> <mrow><mtext>1</mtext> <mtext>&nbsp;/&nbsp;</mtext> <mi>a</mi></mrow></mrow> <mo>)</mo></mrow></mrow></mrow> <mo>)</mo></mrow></mrow> <mo>]</mo></mrow>",
(
r"<mrow><mtext>1</mtext> <mtext>&nbsp;/&nbsp;</mtext> <mrow><mi>Sqrt</mi> <mo>[</mo> "
r"<mrow><mtext>1</mtext> <mtext>&nbsp;+&nbsp;</mtext> <mrow><mtext>1</mtext> <mtext>&nbsp;/&nbsp;</mtext> "
r"<mrow><mo>(</mo> <mrow><mtext>1</mtext> <mtext>&nbsp;+&nbsp;</mtext> <mrow><mtext>1</mtext> <mtext>"
r"&nbsp;/&nbsp;</mtext> <mi>a</mi></mrow></mrow> <mo>)</mo></mrow></mrow></mrow> <mo>]</mo></mrow></mrow>"
),
"Fragile!",
),
"System`OutputForm": (
"<mrow><mi>Sqrt</mi> <mo>[</mo> <mrow><mn>1</mn> <mtext>&nbsp;/&nbsp;</mtext> <mrow><mo>(</mo> <mrow><mn>1</mn> <mtext>&nbsp;+&nbsp;</mtext> <mrow><mn>1</mn> <mtext>&nbsp;/&nbsp;</mtext> <mrow><mo>(</mo> <mrow><mn>1</mn> <mtext>&nbsp;+&nbsp;</mtext> <mrow><mn>1</mn> <mtext>&nbsp;/&nbsp;</mtext> <mi>a</mi></mrow></mrow> <mo>)</mo></mrow></mrow></mrow> <mo>)</mo></mrow></mrow> <mo>]</mo></mrow>",
(
r"<mrow><mn>1</mn> <mtext>&nbsp;/&nbsp;</mtext> <mrow><mi>Sqrt</mi> <mo>["
r"</mo> <mrow><mn>1</mn> <mtext>&nbsp;+&nbsp;</mtext> <mrow><mn>1</mn> "
r"<mtext>&nbsp;/&nbsp;</mtext> <mrow><mo>(</mo> <mrow><mn>1</mn> <mtext>"
r"&nbsp;+&nbsp;</mtext> <mrow><mn>1</mn> <mtext>&nbsp;/&nbsp;</mtext> "
r"<mi>a</mi></mrow></mrow> <mo>)</mo></mrow></mrow></mrow> <mo>]</mo></mrow></mrow>"
),
"Fragile!",
),
},
"latex": {
"System`StandardForm": "\\sqrt{\\frac{1}{1+\\frac{1}{1+\\frac{1}{a}}}}",
"System`TraditionalForm": "\\sqrt{\\frac{1}{1+\\frac{1}{1+\\frac{1}{a}}}}",
"System`InputForm": "\\text{Sqrt}\\left[1\\text{ / }\\left(1\\text{ + }1\\text{ / }\\left(1\\text{ + }1\\text{ / }a\\right)\\right)\\right]",
"System`OutputForm": "\\text{Sqrt}\\left[1\\text{ / }\\left(1\\text{ + }1\\text{ / }\\left(1\\text{ + }1\\text{ / }a\\right)\\right)\\right]",
"System`StandardForm": "\\frac{1}{\\sqrt{1+\\frac{1}{1+\\frac{1}{a}}}}",
"System`TraditionalForm": "\\frac{1}{\\sqrt{1+\\frac{1}{1+\\frac{1}{a}}}}",
"System`InputForm": r"1\text{ / }\text{Sqrt}\left[1\text{ + }1\text{ / }\left(1\text{ + }1\text{ / }a\right)\right]",
"System`OutputForm": r"1\text{ / }\text{Sqrt}\left[1\text{ + }1\text{ / }\left(1\text{ + }1\text{ / }a\right)\right]",
},
},
# Grids, arrays and matrices
Expand Down

0 comments on commit 78496ef

Please sign in to comment.