python, abstract-syntax-tree, purely-functional

Identifying pure functions in Python


I have a decorator @pure that registers a function as pure, for example:

@pure
def rectangle_area(a,b):
    return a*b


@pure
def triangle_area(a,b,c):
    return ((a+(b+c))*(c-(a-b))*(c+(a-b))*(a+(b-c)))**0.5/4

Next, I want to identify a newly defined pure function:

def house_area(a,b,c):
    return rectangle_area(a,b) + triangle_area(a,b,c)

Obviously house_area is pure, since it only calls pure functions.

How can I discover all pure functions automatically (perhaps by using the ast module)?


Solution

  • Assuming all operators are pure, you essentially only need to check the function calls. This can indeed be done with the ast module.

    First I defined the pure decorator as:

    def pure(f):
        f.pure = True
        return f
    

    Adding an attribute that marks the function as pure lets the check return early, or "forces" a function to identify as pure. This is useful when a function such as math.sin should count as pure. Because you can't add attributes to builtin functions, wrap them in a decorated function instead:

    @pure
    def sin(x):
        return math.sin(x)
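
The restriction on builtins can be checked directly. The snippet below is a minimal sketch, assuming CPython; the flag name builtin_accepts_attribute is made up for illustration.

```python
import math

def pure(f):
    # Same decorator as in the answer: mark the function as pure.
    f.pure = True
    return f

# Builtin functions have no instance __dict__, so the assignment fails.
try:
    math.sin.pure = True
    builtin_accepts_attribute = True
except (AttributeError, TypeError):
    builtin_accepts_attribute = False

@pure
def sin(x):
    # Thin Python-level wrapper that can carry the attribute.
    return math.sin(x)

print(builtin_accepts_attribute)
print(sin.pure)
```

The wrapper costs one extra Python call per invocation, which is usually acceptable for this kind of bookkeeping.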
    

    All in all, use the ast module to visit all the nodes, and for each Call node check whether the function being called is pure.

    import ast
    
    class PureVisitor(ast.NodeVisitor):
        def __init__(self, visited):
            super().__init__()
            self.pure = True
            self.visited = visited
    
        def visit_Name(self, node):
            return node.id
    
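        def visit_Attribute(self, node):
            # Build the dotted name (e.g. "math.sin") from nested Attribute nodes.
            name = [node.attr]
            child = node.value
            while isinstance(child, ast.Attribute):
                name.append(child.attr)
                child = child.value
            if not isinstance(child, ast.Name):
                # The base is not a plain name (e.g. a call result); give up.
                return None
            name.append(child.id)
            return ".".join(reversed(name))
    
        def visit_Call(self, node):
            if not self.pure:
                return
            name = self.visit(node.func)
            if name is None:
                # The callee cannot be resolved statically; treat it as impure.
                self.pure = False
                return
            if name not in self.visited:
                self.visited.append(name)
                try:
                    callee = eval(name)
                    if not is_pure(callee, self.visited):
                        self.pure = False
                except (NameError, AttributeError):
                    self.pure = False
            # Also check calls nested inside the arguments, e.g. f(g(x)).
            for arg in node.args:
                self.visit(arg)
            for kw in node.keywords:
                self.visit(kw.value)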
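
To see which names the visitor has to resolve, it can help to list the calls in a piece of source first. The following is a small sketch using ast.walk rather than a NodeVisitor; the helpers dotted_name and called_names are hypothetical and not part of the answer's code.

```python
import ast

def dotted_name(node):
    """Return "a.b.c" for Name/Attribute chains, or None for anything else."""
    parts = []
    while isinstance(node, ast.Attribute):
        parts.append(node.attr)
        node = node.value
    if isinstance(node, ast.Name):
        parts.append(node.id)
        return ".".join(reversed(parts))
    return None

def called_names(source):
    """Collect the dotted name of every call found in the source text."""
    tree = ast.parse(source)
    return [dotted_name(n.func) for n in ast.walk(tree) if isinstance(n, ast.Call)]

source = """
def house_area(a, b, c):
    return rectangle_area(a, b) + math.sqrt(triangle_area(a, b, c))
"""
names = called_names(source)
print(sorted(names))
```

Note that ast.walk also reaches calls nested inside arguments, such as triangle_area inside math.sqrt here.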
    

    Then check whether the function has the pure attribute. If not, get its source code and check whether all of its function calls can be classified as pure.

    import inspect, textwrap
    
    def is_pure(f, _visited=None):
        try:
            return f.pure
        except AttributeError:
            pass
    
        try:
            code = inspect.getsource(f.__code__)
        except (AttributeError, OSError, TypeError):
            # No __code__ (e.g. builtins) or no source available.
            return False
    
        code = textwrap.dedent(code)
        node = compile(code, "<unknown>", "exec", ast.PyCF_ONLY_AST)
    
        if _visited is None:
            _visited = []
    
        visitor = PureVisitor(_visited)
        visitor.visit(node)
        return visitor.pure
    

    Note that print(is_pure(lambda x: math.sin(x))) doesn't work, since inspect.getsource(f.__code__) returns source on a line-by-line basis. The source returned by getsource therefore includes the print and is_pure calls, which yields False unless those functions are overridden.


    To verify that it works, test it by doing:

    print(is_pure(house_area)) # Prints: True
    

    To check all the functions in the current module:

    import sys, types
    
    for k in dir(sys.modules[__name__]):
        v = globals()[k]
        if isinstance(v, types.FunctionType):
            print(k, is_pure(v))
    

    The visited list keeps track of which functions have already been checked. This helps circumvent problems related to recursion: since the code is analyzed rather than executed, the analysis would otherwise keep re-visiting a recursive function such as factorial.

    @pure
    def factorial(n):
        return 1 if n == 1 else n * factorial(n - 1)
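
The role of the visited list can be shown on a stripped-down model in which the call graph is given as a plain dict instead of being read from source. This is only a sketch under that assumption; is_pure_in_graph, calls and known_pure are hypothetical names.

```python
def is_pure_in_graph(name, calls, known_pure, visited=None):
    """calls maps a function name to the names it calls (illustrative data)."""
    if visited is None:
        visited = []
    if name in known_pure:
        return True
    if name not in calls:
        return False  # unknown function: assume impure
    for callee in calls[name]:
        if callee in visited:
            continue  # already checked or in progress; do not recurse again
        visited.append(callee)
        if not is_pure_in_graph(callee, calls, known_pure, visited):
            return False
    return True

calls = {
    "factorial": ["factorial"],  # calls itself
    "house_area": ["rectangle_area", "triangle_area"],
    "log_area": ["print", "house_area"],
}
known_pure = {"rectangle_area", "triangle_area"}

print(is_pure_in_graph("factorial", calls, known_pure))
print(is_pure_in_graph("log_area", calls, known_pure))
```

Without the visited check, the factorial entry would make the function recurse until Python's recursion limit is hit.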
    

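    Note that you might need to revise the following code and choose another way to obtain a function from its name. eval(name) resolves the name in the namespace of the module that defines PureVisitor, which may not be the module where the analyzed function lives.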

    try:
        callee = eval(name)
        if not is_pure(callee, self.visited):
            self.pure = False
    except NameError:
        self.pure = False
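
One possible alternative to eval, sketched here under the assumption that the analyzed function's own globals are the right place to look, is to resolve the dotted name by hand. The helper resolve is hypothetical and would still need to be wired into PureVisitor.

```python
import builtins
import math

def resolve(name, namespace):
    """Look up a dotted name such as "math.sqrt" without using eval."""
    first, *rest = name.split(".")
    if first in namespace:
        obj = namespace[first]
    elif hasattr(builtins, first):
        obj = getattr(builtins, first)
    else:
        raise NameError(name)
    for attr in rest:
        obj = getattr(obj, attr)  # raises AttributeError if missing
    return obj

def house_area(a, b):
    return math.sqrt(a * b)

# A function's __globals__ is the namespace of the module that defined it.
print(resolve("math.sqrt", house_area.__globals__) is math.sqrt)
```

This only sees module-level names and builtins; names bound in enclosing function scopes or rebound at runtime are out of reach for a lookup like this.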