python sympy numexpr lambdify

Sympy lambdify ImmutableDenseMatrix with numexpr


I am trying to speed up the evaluation of a MutableDenseMatrix using lambdify. This works with the 'numpy' module, but 'numexpr' should be faster (I need the evaluation to solve a large optimization problem).

A smaller example of what I am trying to do is given by

from sympy import symbols, cos, Matrix, lambdify

a11, a12, a21, a22, b11, b12, b21, b22, u = symbols("a11 a12 a21 a22 b11 b12 b21 b22 u")
A = Matrix([[a11, a12], [a21, a22]])
B = Matrix([[b11, b12], [b21, b22]])
expr = A * (B ** 2) * cos(u) + A ** (-3 / 2)
f = lambdify((A, B), expr, modules='numexpr')

It raises the error

TypeError: numexpr cannot be used with ImmutableDenseMatrix

Is there a way to use lambdify with numexpr for dense matrices? Or is there another way to speed up the evaluation?

Thanks in advance!


Solution

  • A possible solution using numexpr is to evaluate every matrix entry on its own. The following code returns, as a string, the source of a Python function that evaluates each entry of a matrix expression with numexpr.

    Numexpr with Matrices

    import numpy as np
    import sympy
    
    def lambdify_numexpr(args,expr,expr_name):
        from sympy.printing.lambdarepr import NumExprPrinter as Printer
        printer = Printer({'fully_qualified_modules': False, 'inline': True,'allow_unknown_functions': False})
    
        s=""
        s+="import numexpr as ne\n"
        s+="from numpy import *\n"
        s+="\n"
    
        # Recover each argument's variable name by identity lookup in globals().
        # This only works when the arguments are module-level variables.
        arg_names=[]
        for arg in args:
            name=[ k for k,v in globals().items() if v is arg][0]
            arg_names.append(name)
        arg_names_str=",".join(arg_names)
    
        #Write header
        s+="def "+expr_name+"("+arg_names_str+"):\n"
    
        #unroll array
        for ii in range(len(args)):
            arg=args[ii]
            if arg.is_Matrix:
                for i in range(arg.shape[0]):
                    for j in range(arg.shape[1]):
                        s+="    "+ str(arg[i,j])+" = " + arg_names[ii]+"["+str(i)+","+str(j)+"]\n"
    
        s+="    \n"
        #If the expr is a matrix
        if expr.is_Matrix:
            #write expressions
            for i in range(len(expr)):
                s+="    "+ "res_"+str(i)+" = ne."+printer.doprint(expr[i])+"\n"
                s+="    \n"
    
            res_counter=0
            #write array
            s+="    return concatenate(("
            for i in range(expr.shape[0]):
                s+="("
                for j in range(expr.shape[1]):
                    s+="res_"+str(res_counter)+","
                    res_counter+=1
                s+="),"
            s+="))\n"
    
        #If the expr is not a matrix
        else:
            s+="    "+ "return ne."+printer.doprint(expr)+"\n"
        return s
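The same idea, evaluating each matrix entry on its own, can also be sketched without generating source code by calling lambdify once per entry. This is a minimal sketch and not the code from the answer above: the names `lambdify_entries` and `BACKEND` are hypothetical, the expression is a smaller stand-in for the one in the question, the arguments are passed as flat scalar symbols rather than matrices, and the sketch falls back to the 'numpy' module if numexpr is not installed.

```python
import numpy as np
from sympy import symbols, cos, Matrix, lambdify

try:
    import numexpr  # noqa: F401  (imported only to check availability)
    BACKEND = "numexpr"
except ImportError:
    BACKEND = "numpy"  # assumption: fall back so the sketch still runs

a11, a12, a21, a22, u = symbols("a11 a12 a21 a22 u")
A = Matrix([[a11, a12], [a21, a22]])
expr = A * cos(u) + A**2  # smaller stand-in for the question's expression


def lambdify_entries(scalar_args, matrix_expr, backend=BACKEND):
    """Compile one scalar function per matrix entry and reassemble the results."""
    rows, cols = matrix_expr.shape
    # Scalar expressions are accepted by the numexpr backend; matrices are not.
    funcs = [
        lambdify(scalar_args, matrix_expr[i, j], modules=backend)
        for i in range(rows)
        for j in range(cols)
    ]

    def evaluate(*values):
        # Evaluate the entries in row-major order, then restore the matrix shape.
        flat = [f(*values) for f in funcs]
        return np.array(flat, dtype=float).reshape(rows, cols)

    return evaluate


f = lambdify_entries((a11, a12, a21, a22, u), expr)
result = f(1.0, 2.0, 3.0, 4.0, 0.0)  # A = [[1, 2], [3, 4]], u = 0
print(result)
```

Whether this is actually faster than the 'numpy' module depends on the input sizes: numexpr tends to pay off when the arguments are large arrays, so it is worth timing both backends on the real problem.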