I've got some huge matrices to export that contain only sin(q), cos(q), and sums and products of those. SymPy can compute them and export them to Octave, which is awesome!
However, since these are big matrices, I need some form of common subexpression elimination (CSE), or, even better, dedicated optimization.
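For reference, this is what sp.cse does on its own with a matrix. It returns a list of (symbol, subexpression) pairs and a list of reduced expressions. A single Matrix input comes back as a one-element list holding the reduced Matrix. The 2×2 matrix below is only an illustrative example built from sin(t), cos(t), sums and products:

```python
import sympy as sp

t = sp.symbols('t')
# Illustrative matrix: only sin(t), cos(t), sums and products
A = sp.Matrix([[sp.sin(t)*sp.cos(t), sp.sin(t)**2],
               [sp.cos(t)**2, sp.sin(t)*sp.cos(t) + 1]])

# replacements: list of (new symbol, subexpression) pairs, in dependency order
# reduced: one-element list holding A rewritten in terms of those symbols
replacements, reduced = sp.cse(A)
for sym, sub in replacements:
    print(sym, '=', sub)
print(reduced[0])
```

Substituting the replacements back in reverse order recovers the original matrix, which is a handy sanity check on generated code.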
I found a great tutorial on generating C code with CSE, so I tried to port it, but I got stuck on some details of the printer class. I think it is an infinite recursion, which results in RecursionError: maximum recursion depth exceeded.
My question is: is there an example of how SymPy's Octave code generation and optimization fit together? Or can someone help me get the attached MWE running?
import sympy as sp
from sympy.codegen.ast import Assignment
from sympy.printing.octave import OctaveCodePrinter

t = sp.symbols('t')


class matlabMatrixPrinter(OctaveCodePrinter):
    def _print_ImmutableDenseMatrix(self, expr):
        # Pull common subexpressions out of the matrix
        sub_exprs, simplified = sp.cse(expr)
        lines = []
        for var, sub_expr in sub_exprs:
            lines.append(self._print(Assignment(var, sub_expr)))
        # Assign the reduced matrix to a matrix symbol M
        M = sp.MatrixSymbol('M', *expr.shape)
        # This call ends up in _print_ImmutableDenseMatrix again -> RecursionError
        return '\n'.join(lines) + '\n' + self._print(Assignment(M, simplified[0]))


tmp = sp.sin(t) + sp.sin(t)**2
tmp = sp.ImmutableDenseMatrix(1, 1, [tmp])  # 1x1 matrix holding tmp
se, ex = sp.cse(tmp)
print((ex, se))
print('\n')
#tmp = sp.Matrix([2*sp.sin(t),sp.sin(t)])
p = matlabMatrixPrinter()
print(p.doprint(tmp))
Edit: I have now figured out that the second assignment in the return statement calls _print_ImmutableDenseMatrix again, so this ends in infinite recursion. I don't know why this is no problem for C code in the tutorial, but here the call recurses. It seems that only the simplified matrix itself cannot be passed to self._print. Maybe someone who knows these printers can explain how one should print matrices and this single assignment?
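A likely reason, based on a reading of SymPy's printer source (worth double-checking against the installed version): the generic CodePrinter._print_Assignment, which the C printer uses, splits an assignment whose left-hand side is a MatrixSymbol into one scalar assignment per element, so the whole matrix is never handed back to the matrix printing hook. OctaveCodePrinter overrides _print_Assignment without that special case and prints the complete right-hand side, which re-enters _print_ImmutableDenseMatrix. The sketch below compares the two printers on the same Assignment; ccode, octave_code and sympy.codegen.ast.Assignment exist in SymPy, while the matrix A is only an example.

```python
import sympy as sp
from sympy.codegen.ast import Assignment

t = sp.symbols('t')
A = sp.ImmutableDenseMatrix([sp.sin(t), sp.sin(t)**2])
M = sp.MatrixSymbol('M', 2, 1)

# C printer: one scalar statement per matrix element (M[0] = ..., M[1] = ...)
c_out = sp.ccode(Assignment(M, A))
# Octave printer: a single statement whose right-hand side is the whole matrix,
# i.e. the matrix itself is passed back to self._print
oct_out = sp.octave_code(Assignment(M, A))
print(c_out)
print(oct_out)
```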
After a lot of experimenting, I feel I still understand only part of the intended workflow of SymPy's code printers. Still, I wrote a subclass that does exactly what I intended (careful: it probably doesn't work with anything other than matrices!).
Maybe this is of use to someone! For me it definitely validates SymPy as a working tool, since without CSE the thousands of sin evaluations would make the generated code absolutely unviable.
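To put a rough number on that saving, one can count sin/cos calls before and after sp.cse. The matrix below, a product of two 2×2 rotation matrices in hypothetical angles q1 and q2, is only an illustrative stand-in for the real matrices; the helper trig_calls is likewise made up for this sketch.

```python
import sympy as sp

# Hypothetical angles; the matrix is only an illustrative stand-in
q1, q2 = sp.symbols('q1 q2')
R1 = sp.Matrix([[sp.cos(q1), -sp.sin(q1)], [sp.sin(q1), sp.cos(q1)]])
R2 = sp.Matrix([[sp.cos(q2), -sp.sin(q2)], [sp.sin(q2), sp.cos(q2)]])
M = R1 * R2  # every entry mixes sin/cos of both angles


def trig_calls(exprs):
    """Count every sin/cos node (repeats included) in a list of expressions."""
    return sum(
        1
        for e in exprs
        for node in sp.preorder_traversal(e)
        if isinstance(node, (sp.sin, sp.cos))
    )


before = trig_calls(list(M))
replacements, reduced = sp.cse(M)
# After CSE, trig calls can only remain in the replacement definitions
# or in the reduced matrix entries
after = trig_calls([rhs for _, rhs in replacements] + list(reduced[0]))
print(before, after)
```

Each distinct sin/cos ends up evaluated once, so the gap grows with the number of entries that reuse the same angles.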
I would still be very interested in comments and thoughts from someone who knows how these features SHOULD be implemented!
import sympy as sp
from sympy.codegen.ast import Assignment
from sympy.printing.octave import OctaveCodePrinter

t = sp.symbols('t')


class matlabMatrixPrinter(OctaveCodePrinter):
    def print2(self, expr_list, names=None):
        # CSE over all matrices at once, so shared terms are computed only once
        sub_exprs, simplified = sp.cse(expr_list)
        lines = []
        for var, sub_expr in sub_exprs:
            lines.append(self._print(Assignment(var, sub_expr)))
        lines.append('')
        # _print_ImmutableDenseMatrix is not overridden here, so the default
        # Octave matrix printing is used for each right-hand side
        for k, expr in enumerate(simplified):
            if names:
                M = sp.MatrixSymbol(names[k], *expr.shape)
            else:
                M = sp.MatrixSymbol('M{k}'.format(k=k), *expr.shape)
            lines.append(self._print(Assignment(M, expr)))
        return '\n'.join(lines)


tmp = sp.Matrix([sp.sin(t) + sp.sin(t)**2])
tmp2 = sp.Matrix([sp.sin(t), sp.cos(t), 2*sp.sin(t), sp.cos(t)**2])
p = matlabMatrixPrinter()
#print(p.print2([tmp, tmp2]))
print(p.print2([tmp, tmp2], ['scalar_matrix', 'matrix']))
And this gives the expected output:
x0 = sin(t);
x1 = cos(t);
scalar_matrix = x0.^2 + x0;
matrix = [x0; x1; 2*x0; x1.^2];
As described above: use at your own risk :)
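For comparison, here is a sketch of how the original override idea from the question could avoid the recursion: instead of printing Assignment(M, reduced), which sends the whole matrix back through the same hook, it builds the Octave matrix literal from individually printed elements. The class name CSEMatrixPrinter and the fixed output name M are hypothetical, and the sketch assumes a non-empty dense matrix passed as an ImmutableDenseMatrix.

```python
import sympy as sp
from sympy.codegen.ast import Assignment
from sympy.printing.octave import OctaveCodePrinter


class CSEMatrixPrinter(OctaveCodePrinter):
    """Hypothetical printer: runs CSE whenever an immutable dense matrix is printed."""

    def _print_ImmutableDenseMatrix(self, expr):
        sub_exprs, simplified = sp.cse(expr)
        reduced = simplified[0]  # cse returns a one-element list for a single matrix
        # One Octave statement per extracted subexpression (scalars only)
        lines = [self._print(Assignment(var, sub)) for var, sub in sub_exprs]
        # Build the matrix literal from individually printed elements instead of
        # printing Assignment(M, reduced); every element is a scalar, so
        # self._print never re-enters this method.
        if reduced.shape == (1, 1):
            rhs = self._print(reduced[0, 0])
        else:
            rows = [' '.join(self._print(e) for e in reduced.row(i))
                    for i in range(reduced.rows)]
            rhs = '[' + '; '.join(rows) + ']'
        lines.append('M = ' + rhs + ';')
        return '\n'.join(lines)


t = sp.symbols('t')
mat = sp.ImmutableDenseMatrix([sp.sin(t) + sp.sin(t)**2, sp.cos(t)])
out = CSEMatrixPrinter().doprint(mat)
print(out)
```

The trade-off against print2 is that this version hides CSE inside doprint but can only name its output M; print2 keeps the naming and the multi-matrix CSE explicit.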