SymPy understands that Sum and + can be interchanged, i.e. that the sum of a sum of terms equals the sum of the individual sums:
from sympy import symbols, Idx, IndexedBase, Sum
i, n = symbols("i n", cls=Idx)
x = IndexedBase("x", shape=(n,))
s = Sum(x[i] + 1, (i, 1, n))
t = Sum(x[i], (i, 1, n)) + Sum(1, (i, 1, n))
assert s.equals(t) # OK
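One way to see the linearity involved is a minimal sketch like the one below, assuming a reasonably recent SymPy release. Calling expand() on a Sum whose summand is an Add, over limits not known to be infinite, splits it into a sum of Sums. The variable name split is only illustrative.

```python
from sympy import symbols, Idx, IndexedBase, Sum

i, n = symbols("i n", cls=Idx)
x = IndexedBase("x", shape=(n,))

s = Sum(x[i] + 1, (i, 1, n))

# expand() distributes the Sum over the + in its summand,
# producing Sum(x[i], ...) + Sum(1, ...)
split = s.expand()
print(split)
```

The complex example below relies on the same mechanism, but its summand first has to be multiplied out before there is an Add for Sum to distribute over.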
But for more complex expressions, it does not work:
from sympy import symbols, Idx, IndexedBase, Sum
a, b = symbols("a b")
i, n = symbols("i n", cls=Idx)
w = IndexedBase("w", shape=(n,))
x = IndexedBase("x", shape=(n,))
y = IndexedBase("y", shape=(n,))
sum = lambda e: Sum(e, (i, 1, n))  # note: shadows Python's built-in sum()
sw = sum(w[i])
mx = sum(w[i] * x[i]) / sw
my = sum(w[i] * y[i]) / sw
d = w[i] * ((a * (x[i] - mx) - (y[i] - my)))**2
e = w[i] * (a * mx + b - my)**2
f = w[i] * 2 * (a * (x[i] - mx) - (y[i] - my)) * (a * mx + b - my)
s = sum(d + e + f)
t = sum(d) + sum(e) + sum(f)
assert s.equals(t) # The assert fails
How can we convince SymPy that this transformation is actually valid?
You can use expand() and simplify():
from sympy import symbols, Idx, IndexedBase, Sum, expand, simplify
a, b = symbols("a b")
i, n = symbols("i n", cls=Idx)
w = IndexedBase("w", shape=(n,))
x = IndexedBase("x", shape=(n,))
y = IndexedBase("y", shape=(n,))
sum = lambda e: Sum(e, (i, 1, n))  # note: shadows Python's built-in sum()
sw = sum(w[i])
mx = sum(w[i] * x[i]) / sw
my = sum(w[i] * y[i]) / sw
d = w[i] * ((a * (x[i] - mx) - (y[i] - my)))**2
e = w[i] * (a * mx + b - my)**2
f = w[i] * 2 * (a * (x[i] - mx) - (y[i] - my)) * (a * mx + b - my)
s = sum(d + e + f)
t = sum(d) + sum(e) + sum(f)
se, te = expand(s), expand(t)
ss, ts = simplify(se), simplify(te)
print(ss.equals(ts))
This prints:
True
SymPy deliberately avoids automatic simplification, so you should expect to do some of it yourself to show that two expressions are the same. If equals also expanded its arguments, your original check might have worked, but at the outset it only computes diff = factor_terms(simplify(self - other), radical=True).
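To make the idea concrete, here is a hedged sketch of a helper that mirrors that first step of equals but expands the difference beforehand, so that each Sum gets split over + before simplify and factor_terms run. The name expands_to_equal is hypothetical and not part of SymPy; it only returns True when the reduced difference is literally 0 and is not a full replacement for equals.

```python
from sympy import symbols, Idx, IndexedBase, Sum, expand, simplify, factor_terms


def expands_to_equal(lhs, rhs):
    """Hypothetical helper: True if lhs - rhs reduces to 0 after expansion.

    Mimics the first step of Expr.equals (simplify, then factor_terms with
    radical=True), but calls expand() first so that Sums are distributed over +.
    """
    diff = factor_terms(simplify(expand(lhs - rhs)), radical=True)
    return diff == 0


# Usage on the simple example from the question
i, n = symbols("i n", cls=Idx)
x = IndexedBase("x", shape=(n,))

same = expands_to_equal(Sum(x[i] + 1, (i, 1, n)),
                        Sum(x[i], (i, 1, n)) + Sum(1, (i, 1, n)))
different = expands_to_equal(Sum(x[i], (i, 1, n)), Sum(2 * x[i], (i, 1, n)))
print(same, different)
```

For the larger expression in the question, expansion alone may leave separate Sums with like terms that do not cancel automatically, which is why the answer above also applies simplify() after expand().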