
Error in a custom SymPy Function when trying to print the squared function with LaTeX


I would like to write a custom SymPy Function with special LaTeX output. A simplified function might look like the following:

from sympy import Function, Symbol, latex

class TestClass(Function):
    def _latex(self, printer):
        m = self.args[0]
        _m = printer._print(m)
        return _m + '(x)'

S = Symbol('S')
latex(TestClass(S)**2)

However, this gives me the error

TypeError: TestClass._latex() got an unexpected keyword argument 'exp'

Can someone please help me to understand what is going wrong here?
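A minimal reproduction, assuming a SymPy version whose LaTeX printer forwards an `exp` keyword to `_latex` for powers of functions, shows that printing the function on its own works and only the power fails (the variable names `plain` and `error_message` are just for illustration):

```python
from sympy import Function, Symbol, latex

class TestClass(Function):
    # Hook from the question: it accepts no extra keyword arguments
    def _latex(self, printer):
        return printer._print(self.args[0]) + '(x)'

S = Symbol('S')

# Printing the function on its own calls _latex(printer) and succeeds
plain = latex(TestClass(S))
print(plain)

# Printing a power of it makes the printer call _latex(printer, exp=...)
error_message = ""
try:
    latex(TestClass(S)**2)
except TypeError as err:
    error_message = str(err)
print("TypeError:", error_message)
```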


Solution

  • TestClass is a subclass of Function.

    TestClass(S)**2 is not a TestClass instance but a Pow whose base is TestClass(S). Let's look at the source code of the LatexPrinter class, specifically at the relevant branch of its _print_Pow method:

            if expr.base.is_Function:
                return self._print(expr.base, exp=self._print(expr.exp))
    

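    Under the hood, self._print checks whether expr.base implements a _latex method. If it does (as in your test case), it calls that method and passes along any keyword arguments, here exp, the exponent already rendered as a LaTeX string. Because your _latex does not accept exp, Python raises the TypeError.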

    So, you need to let _latex accept exp and append the exponent yourself:

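    from sympy import Function, Symbol, latex

    class TestClass(Function):
        def _latex(self, printer, exp=None):
            m = self.args[0]
            # printer._print returns the bare LaTeX for the argument
            # (printer.doprint would also apply the $...$ wrapping of mode='inline')
            _m = printer._print(m)
            base = _m + '(x)'
            if exp is None:
                # printed on its own, e.g. latex(TestClass(S))
                return base
            # printed as the base of a power; exp is already a LaTeX string
            return base + "^{%s}" % exp
    
    S = Symbol('S')
    expr = TestClass(S)
    latex(expr**3)
    # out: 'S(x)^{3}'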
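To see the mechanism directly, here is a small sketch that records the keyword arguments the printer hands to _latex. The Probe class and the seen_kwargs list are hypothetical names for illustration, and it assumes a SymPy version whose _print_Pow matches the excerpt above:

```python
from sympy import Function, Rational, Symbol, latex

# Hypothetical log of the keyword arguments each _latex call receives
seen_kwargs = []

class Probe(Function):
    # Accept any keyword arguments so that no call can fail
    def _latex(self, printer, **kwargs):
        seen_kwargs.append(kwargs)
        return r"\mathrm{P}"

x = Symbol('x')
latex(Probe(x))                  # printed on its own: no extra kwargs
latex(Probe(x)**2)               # printed via _print_Pow: exp is passed
latex(Probe(x)**Rational(3, 2))  # the exponent arrives already rendered
print(seen_kwargs)
```

Printed on its own, the function's _latex receives no extra keyword arguments. As the base of a power it receives exp, whose value is already LaTeX (for a 3/2 exponent it is \frac{3}{2}), so TestClass can splice it straight into ^{...}. Some exponents, such as 1/2 with the default root_notation setting, are handled by other branches of _print_Pow and are printed as \sqrt{...} without ever passing exp.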