
Sympy: hierarchically collect an equation with more than one factor expression?


Can we collect/factor a symbolic expression with SymPy using more than one factor expression, and do so hierarchically? Here is an example: a long calculation produces the expanded expression f. In my actual computation, the expressions are larger and more complex.

import sympy
a,b,m,n = sympy.symbols('a b m n')
f = -a**2*b*m**2 - a**2*b*m*n - a**2*b*m + a**2*m**2*n + a**2*m**2 + 2*a**2*m + a**2*n**2 - a*b*m - a*b*n + 5*a*m + a*n**2 + a*n + b*m**2 - 2*b*m*n - b*m - b*n**2 - m**2 + 2*m*n + n**2 + 1
g=sympy.collect(f,(a**2,1-b))

I expect a function (collect?) to hierarchically aggregate terms around a**2 first, and then around (1-b), so that f is written as the expression g below:

g = a**2 * ( (1-b) * (m+n*m+m**2)+m+n**2-m*n+m**2*n)+a*((1-b)*(m+n)+n**2+4*m)+(1-b) * (m+n**2-m**2+2*m*n) +1 - m

where the expression is factorised/grouped around a**2 and then around (1-b).

Of course, f = g, which can be verified in my example:

h = sympy.simplify(f-g)
print(h)
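
For a polynomial identity like this one, a full simplify is probably more than is needed: expanding the difference and comparing it with zero should be enough. A minimal sketch, reusing the f and g from the question (the name difference is only illustrative):

```python
import sympy

a, b, m, n = sympy.symbols('a b m n')

# Expanded expression from the question.
f = (-a**2*b*m**2 - a**2*b*m*n - a**2*b*m + a**2*m**2*n + a**2*m**2
     + 2*a**2*m + a**2*n**2 - a*b*m - a*b*n + 5*a*m + a*n**2 + a*n
     + b*m**2 - 2*b*m*n - b*m - b*n**2 - m**2 + 2*m*n + n**2 + 1)

# Hand-written target form, grouped by a**2, a and (1 - b).
g = (a**2*((1 - b)*(m + n*m + m**2) + m + n**2 - m*n + m**2*n)
     + a*((1 - b)*(m + n) + n**2 + 4*m)
     + (1 - b)*(m + n**2 - m**2 + 2*m*n) + 1 - m)

# Expanding f - g cancels the terms one by one if the two forms agree.
difference = sympy.expand(f - g)
print(difference)
```

Since both sides are plain polynomials, expand avoids the heuristics that simplify runs, which may matter for much larger expressions.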

Edit: a more detailed explanation of the compact but efficient answer from smichr:

import sympy
a,b,m,n, y = sympy.symbols('a b m n y')
f1 = -a**2*b*m**2 - a**2*b*m*n - a**2*b*m + a**2*m**2*n + a**2*m**2 + 2*a**2*m + a**2*n**2 - a*b*m - a*b*n + 5*a*m + a*n**2 + a*n + b*m**2 - 2*b*m*n - b*m - b*n**2 - m**2 + 2*m*n + n**2 + 1
g1=sympy.collect(f1,(a**2,1-b))

# Replace b by 1 - y. SymPy seems to handle substituting a symbol more easily
# than collecting on a sum, and we want to factorise by y = 1 - b, i.e. b = 1 - y.
f1_1 = f1.subs(b, 1 - y)
# Expand the products involving y so that y can be collected afterwards.
f1_2 = f1_1.expand()
# Collect the expanded expression by a; this groups both a**2 and a.
f1_3 = sympy.collect(f1_2, a)
# Collect the already a-collected expression by y; this gives the two-level hierarchy.
f1_4 = sympy.collect(f1_3, y)
# Substitute y back by its original expression 1 - b.
f1_5 = f1_4.subs(y, 1 - b)
# To collect on the exact factor only (and not on its powers), pass exact=True to sympy.collect.
print(f1_5)

It results in:

a**2 * (m**2*n - m*n + m + n**2 + (1 - b) * (m**2 + m*n + m)) +
a * (4*m + n**2 + (1 - b)*(m + n)) +
(1 - b)*(-m**2 + 2*m*n + m + n**2) + 1 - m

This is factorised hierarchically, first by a (both a² and a), then by (1-b). Thanks to smichr and Stack Overflow.
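
The substitute, expand, collect, substitute-back steps above can be wrapped in a small helper. This is only a sketch: collect_nested is a hypothetical name, not a SymPy function, and it assumes that the inner factor (here 1 - b) can be solved uniquely for one of its symbols.

```python
import sympy

def collect_nested(expr, outer, inner_expr, inner_sym):
    """Collect expr on outer first, then on inner_expr.

    Hypothetical helper (not part of SymPy). It assumes that
    inner_expr == placeholder has exactly one solution for inner_sym.
    """
    placeholder = sympy.Dummy('y')
    # Invert the inner factor, e.g. 1 - b == y  gives  b == 1 - y.
    solutions = sympy.solve(inner_expr - placeholder, inner_sym)
    if len(solutions) != 1:
        raise ValueError('inner_expr must be uniquely solvable for inner_sym')
    # Substitute, then expand so the placeholder appears as a plain factor.
    substituted = sympy.expand(expr.subs(inner_sym, solutions[0]))
    # Collect on the outer symbol first, then on the placeholder.
    collected = sympy.collect(sympy.collect(substituted, outer), placeholder)
    # Put the original inner factor back in place of the placeholder.
    return collected.subs(placeholder, inner_expr)

a, b, m, n = sympy.symbols('a b m n')
f = (-a**2*b*m**2 - a**2*b*m*n - a**2*b*m + a**2*m**2*n + a**2*m**2
     + 2*a**2*m + a**2*n**2 - a*b*m - a*b*n + 5*a*m + a*n**2 + a*n
     + b*m**2 - 2*b*m*n - b*m - b*n**2 - m**2 + 2*m*n + n**2 + 1)

result = collect_nested(f, a, 1 - b, b)
print(result)
```

Using a Dummy for the placeholder avoids clashing with a symbol named y that might already appear in the expression.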


Solution

  • Collection works better on symbols than on sums, so try the following, which gives the same thing as the g you wanted:

    >>> from sympy import collect
    >>> from sympy.abc import y,a,b,m,n
    >>> f = (-a**2*b*m**2 - a**2*b*m*n - a**2*b*m +
    ...    a**2*m**2*n + a**2*m**2 + 2*a**2*m +
    ...    a**2*n**2 - a*b*m - a*b*n + 5*a*m + a*n**2 +
    ...    a*n + b*m**2 - 2*b*m*n - b*m - b*n**2 - m**2 +
    ...    2*m*n + n**2 + 1)
    >>> g = (a**2*((1-b)*(m+n*m+m**2)+m+n**2-m*n+m**2*n)+
    ...     a*((1-b)*(m+n)+n**2+4*m)+(1-b)*(m+n**2-m**2+2*m*n)+1-m)
    >>> g == collect(collect(f.subs(b,1-y).expand(),a),y).subs(y,1-b)
    True
    

    The expand() call distributes products such as m*(1 - y) into m - m*y so that collect can gather the terms in y. I collect on a (rather than on a**2 with the flag exact=True) because your g expression is collected on a as well, not only on a**2.
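
To see which groups a collect on a produces, one option is the evaluate=False flag of collect, which returns the collected coefficients as a dictionary instead of a single expression. A small sketch, assuming the f from the question (the name groups is illustrative):

```python
import sympy
from sympy.abc import a, b, m, n

f = (-a**2*b*m**2 - a**2*b*m*n - a**2*b*m + a**2*m**2*n + a**2*m**2
     + 2*a**2*m + a**2*n**2 - a*b*m - a*b*n + 5*a*m + a*n**2 + a*n
     + b*m**2 - 2*b*m*n - b*m - b*n**2 - m**2 + 2*m*n + n**2 + 1)

# With evaluate=False, collect returns {power of a: coefficient};
# terms free of a are stored under the key 1.
groups = sympy.collect(f, a, evaluate=False)
for power, coefficient in groups.items():
    print(power, '->', coefficient)
```

Each coefficient can then be processed on its own, for example by substituting b with 1 - y and collecting on y as in the answer above.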