I'm digging into SymPy's code generation capabilities and am struggling with a few basic things. Ironically, of all the languages SymPy supports for code generation, the documentation for Python code generation seems to be the most minimal. I would like to use `numpy as np` as the default for all math functions in the generated code. Looking at the source code (since it is not documented), it looks like you can pass a settings dict with a `user_functions` field that maps SymPy functions to custom functions. Here is a basic example:
```python
import sympy as sp
from sympy.printing.pycode import PythonCodePrinter

class MyPrinter(PythonCodePrinter):
    def __init__(self):
        super().__init__({'user_functions': {'cos': 'np.cos', 'sin': 'np.sin', 'sqrt': 'np.sqrt'}})

x = sp.symbols('x')
expr = sp.sqrt(x) + sp.cos(x)
mpr = MyPrinter()
mpr.doprint(expr)
```
This produces the following output:

```
'math.sqrt(x) + np.cos(x)'
```
You can see that the mapping worked correctly for `cos` but not for `sqrt`.
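The reason `sqrt` slips through is that SymPy has no `sqrt` function class: `sp.sqrt(x)` returns `Pow(x, 1/2)`, so a `user_functions` entry keyed on `'sqrt'` is never consulted, and `PythonCodePrinter` prints the half power through its `Pow` handling, which emits `math.sqrt`. A quick check of the expression types (a minimal sketch; the variable names are only for illustration):

```python
import sympy as sp

x = sp.symbols('x')

# sqrt is a helper that builds a power, not an instance of a "sqrt" function class
root = sp.sqrt(x)
print(sp.srepr(root))        # Pow(Symbol('x'), Rational(1, 2))
print(type(root).__name__)   # Pow

# cos(x) is an instance of the function class `cos`, so the printer
# looks its class name up in known_functions / user_functions
print(type(sp.cos(x)).__name__)  # cos
```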
There is a `sympy.printing.numpy.NumPyPrinter` class that uses `numpy` functions instead of `math` functions. The docstring of its `__init__` says:

> `module` specifies the array module to use, currently 'NumPy', 'CuPy' or 'JAX'.
However, the method doesn't take a `module` argument. Moreover, the module name is hardcoded as `numpy` in the source code. Also, the lines that define the known functions and constants for the NumPy printer use a hardcoded `'numpy.'` prefix to build the values of the `_kf` and `_kc` dictionaries, even though there are separate `CuPyPrinter` and `JAXPrinter` classes that define their own `_module`, `_kf`, and `_kc` variables.
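The way `CuPyPrinter` and `JAXPrinter` specialise the printer can be imitated by overriding the same three class attributes in a subclass. A minimal sketch, assuming a SymPy version whose `sympy.printing.numpy` module still exposes the private `_known_functions_numpy` and `_known_constants_numpy` tables; `NpAliasPrinter` is a hypothetical name, not part of SymPy:

```python
import sympy as sp
import sympy.printing.numpy as spn

class NpAliasPrinter(spn.NumPyPrinter):
    """Hypothetical subclass that prints the module as the alias 'np'.

    Overrides _module, _kf and _kc as class attributes, the same three
    names CuPyPrinter/JAXPrinter set. The _known_* tables are private
    SymPy names and may change between versions.
    """
    _module = 'np'
    _kf = {k: 'np.' + v for k, v in spn._known_functions_numpy.items()}
    _kc = {k: 'np.' + v for k, v in spn._known_constants_numpy.items()}

x = sp.symbols('x')
expr = sp.sqrt(x) + sp.cos(x)

# Stock printer: hardcoded "numpy." prefix
print(spn.NumPyPrinter().doprint(expr))  # numpy.sqrt(x) + numpy.cos(x)

# Subclass: alias applied to both known functions and the Pow/sqrt path
code = NpAliasPrinter().doprint(expr)
print(code)  # np.sqrt(x) + np.cos(x)
```

This fixes the alias per class rather than per instance, which is the pattern the CuPy and JAX printers follow.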
IMO the simplest fix is to extend the `NumPyPrinter` class so that it takes a `module` argument and uses it to build its `_kf` and `_kc` dictionaries, which the inherited `__init__` then uses to create the `known_functions` and `known_constants` dictionaries:
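```python
import sympy.printing.numpy as spn

class NumPyPrinter(spn.NumPyPrinter):
    def __init__(self, settings=None, module='numpy'):
        # Prefix used by NumPyPrinter's own _print_* methods (e.g. for sqrt)
        self._module = module
        m = module + "."
        # Rebuild the lookup tables with the chosen prefix instead of the hardcoded 'numpy.'
        self._kf = {k: m + v for k, v in spn._known_functions_numpy.items()}
        self._kc = {k: m + v for k, v in spn._known_constants_numpy.items()}
        # The parent __init__ turns _kf/_kc into known_functions/known_constants
        super().__init__(settings=settings)
```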
Using this printer on your expression now gives the expected output, `np.sqrt(x) + np.cos(x)`:

```python
import sympy as sp

x = sp.symbols('x')
expr = sp.sqrt(x) + sp.cos(x)
npr = NumPyPrinter(module='np')
print(npr.doprint(expr))
```
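The printed string still needs an `import numpy as np` line before it can run. A minimal sketch of wrapping it into a standalone function source, assuming the `NumPyPrinter` subclass above (repeated so the block runs on its own); `make_source` is a hypothetical helper, not a SymPy API:

```python
import sympy as sp
import sympy.printing.numpy as spn

class NumPyPrinter(spn.NumPyPrinter):
    # Same module-aware printer as above
    def __init__(self, settings=None, module='numpy'):
        self._module = module
        m = module + "."
        self._kf = {k: m + v for k, v in spn._known_functions_numpy.items()}
        self._kc = {k: m + v for k, v in spn._known_constants_numpy.items()}
        super().__init__(settings=settings)

def make_source(name, args, expr, alias='np'):
    """Hypothetical helper: emit Python source for a function evaluating expr with NumPy."""
    body = NumPyPrinter(module=alias).doprint(expr)
    params = ', '.join(str(a) for a in args)
    return (
        f"import numpy as {alias}\n\n"
        f"def {name}({params}):\n"
        f"    return {body}\n"
    )

x = sp.symbols('x')
# Constants such as pi go through _kc, so they get the alias too
src = make_source('f', [x], sp.sqrt(x) + sp.cos(x) + sp.pi)
print(src)
```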