
Cleanest way to compute less in a function by implicitly knowing what the callback / caller function is in Python?


The specific use case I have in mind: say I have two matrices like

A = [[1, 1], [1, 1]]
B = [[2, 2], [2, 2]]

and I want to write a function, multiply, that computes their matrix product like this:

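def multiply(X, Y):
    # the ij-th entry of the result is row i of X dotted with column j of Y
    result = [[sum(X[i][k] * Y[k][j] for k in range(len(Y)))
               for j in range(len(Y[0]))]
              for i in range(len(X))]
    return result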

However, say that immediately after this computation I apply a trace operation (the sum of the diagonal entries). Then I obviously don't care about the off-diagonal entries of the resulting matrix. So my question is: what is the cleanest way in Python to tell multiply that the only operation applied to its return value will need just a small part of the result, so that it computes only that part? The first thing I think of is something like

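def multiply(X, Y, only_diag=False):
    def entry(i, j):
        # ij-th entry: row i of X dotted with column j of Y
        return sum(X[i][k] * Y[k][j] for k in range(len(Y)))
    n, p = len(X), len(Y[0])
    if only_diag:
        # only compute the (i, i) entries; off-diagonal entries are left as 0
        return [[entry(i, j) if i == j else 0 for j in range(p)] for i in range(n)]
    else:
        # compute all entries
        return [[entry(i, j) for j in range(p)] for i in range(n)]

result = trace(multiply(A, B, only_diag=True))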

However, I'm interested to know whether there is a way to implement multiply so that it recognizes trace as the caller and therefore implicitly knows that only_diag=True.

Thanks :)


Solution

  • I don't see a clean way for multiply() to know which function will be called on its result afterwards, as in your example.

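    In the code trace(multiply(A, B)), the multiply() function is called before trace() rather than being called by trace(). Python evaluates a call's arguments before making the call, so multiply() has already returned by the time trace() starts, and trace() never appears on multiply()'s call stack.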
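A quick way to see this is to log when each function starts and finishes. This is a minimal sketch; the events list and the plain-list implementations of multiply and trace are illustrative only:

```python
# Record the order in which the two calls start and finish.
events = []

def multiply(X, Y):
    events.append("multiply starts")
    # ij-th entry: row i of X dotted with column j of Y
    result = [[sum(X[i][k] * Y[k][j] for k in range(len(Y)))
               for j in range(len(Y[0]))]
              for i in range(len(X))]
    events.append("multiply returns")
    return result

def trace(Z):
    events.append("trace starts")
    # sum of the diagonal entries
    return sum(Z[i][i] for i in range(len(Z)))

A = [[1, 1], [1, 1]]
B = [[2, 2], [2, 2]]
print(trace(multiply(A, B)))  # 8
print(events)  # ['multiply starts', 'multiply returns', 'trace starts']
```

multiply() finishes completely before trace() even begins, so there is no point at which multiply() could discover that trace() is going to consume its result.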

    In CPython, if you want to know the caller function, then the sys._getframe() function can help:

    import sys
    from inspect import getframeinfo
    
    def multiply(X, Y):
        print(getframeinfo(sys._getframe(1)))  # frame 1 = the function that called multiply()
    
    def trace(Z):
        pass
    
    def multiply_then_trace(X, Y):
        Z = multiply(X, Y)
        return trace(Z)
    
    if __name__ == '__main__':
        A = [[1, 1], [1, 1]]
        B = [[2, 2], [2, 2]]
        multiply_then_trace(A, B)
    

    This prints something like (the exact repr varies between Python versions):

    Traceback(filename='/Users/raymond/Documents/tmp3.py',
              lineno=11,
              function='multiply_then_trace',
              code_context=['    Z = multiply(X, Y)\n'],
              index=0)
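The same holds when the calls are nested on one line, as in the original trace(multiply(A, B)): the caller that multiply() sees is the enclosing function, not trace(). A minimal sketch using inspect.stack() (a public alternative to sys._getframe()); main and callers are illustrative names:

```python
import inspect

callers = []  # names of the functions that called multiply()

def multiply(X, Y):
    # inspect.stack()[0] is multiply's own frame, [1] is its caller's frame
    callers.append(inspect.stack()[1].function)
    return [[sum(X[i][k] * Y[k][j] for k in range(len(Y)))
             for j in range(len(Y[0]))]
            for i in range(len(X))]

def trace(Z):
    return sum(Z[i][i] for i in range(len(Z)))

def main():
    A = [[1, 1], [1, 1]]
    B = [[2, 2], [2, 2]]
    return trace(multiply(A, B))

print(main())   # 8
print(callers)  # ['main']
```

Note that inspect.stack() gathers information for every frame on the stack, so it is noticeably slower than sys._getframe(1) and would eat into any savings you were hoping for.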
    

    To achieve your goal of faster code execution, the easiest thing to do is just write a trace_multiply(X, Y) function that has custom optimized code and invoke it explicitly rather than trying to cobble together automatic detection.
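For the trace case specifically, the saving comes from the identity tr(XY) = Σ_i Σ_k X[i][k]·Y[k][i]: only the diagonal of the product is ever needed. A minimal sketch of such a trace_multiply, assuming plain nested lists as in the question:

```python
def trace_multiply(X, Y):
    """Return trace(X @ Y) without building the full product.

    Only the diagonal entries (XY)[i][i] = sum_k X[i][k] * Y[k][i] are
    needed, so this does len(X) * len(Y) multiplications instead of the
    len(X) * len(Y) * len(Y[0]) needed for every entry of the product.
    """
    if len(X) != len(Y[0]):
        raise ValueError("trace is only defined when X @ Y is square")
    return sum(X[i][k] * Y[k][i]
               for i in range(len(X))
               for k in range(len(Y)))

A = [[1, 1], [1, 1]]
B = [[2, 2], [2, 2]]
print(trace_multiply(A, B))  # 8
```

For n x n matrices that is n² multiplications instead of n³.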

    Another alternative is lazy evaluation of the matrix multiplication: entries of the product are computed only when something reads them, so the parts of the product you never use are never computed.
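Lazy evaluation could look roughly like this. It is a minimal sketch, not a library API: LazyProduct and computed_entries are hypothetical names, and trace is rewritten to index entries as Z[i, j]:

```python
class LazyProduct:
    """Hypothetical lazy matrix product of X (n x m) and Y (m x p).

    No entries are computed up front; entry (i, j) is computed the first
    time it is requested and then cached.
    """

    def __init__(self, X, Y):
        self.X, self.Y = X, Y
        self.shape = (len(X), len(Y[0]))
        self._cache = {}

    def __getitem__(self, index):
        i, j = index
        if index not in self._cache:
            # first access: row i of X dotted with column j of Y
            self._cache[index] = sum(self.X[i][k] * self.Y[k][j]
                                     for k in range(len(self.Y)))
        return self._cache[index]

    @property
    def computed_entries(self):
        # how many distinct entries have actually been computed so far
        return len(self._cache)

def multiply(X, Y):
    return LazyProduct(X, Y)

def trace(Z):
    # only touches the diagonal entries Z[i, i]
    n, p = Z.shape
    if n != p:
        raise ValueError("trace needs a square matrix")
    return sum(Z[i, i] for i in range(n))

A = [[1, 1], [1, 1]]
B = [[2, 2], [2, 2]]
Z = multiply(A, B)
print(trace(Z))            # 8
print(Z.computed_entries)  # 2 (of 4)
```

With this design the call keeps its original shape, trace(multiply(A, B)), and multiply() no longer needs to know who consumes its result: each consumer only pays for the entries it actually reads.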

    Hope this helps :-)