
I have a decorator @pure that registers a function as pure, for example:

@pure
def rectangle_area(a,b):
    return a*b


@pure
def triangle_area(a,b,c):
    return ((a+(b+c))*(c-(a-b))*(c+(a-b))*(a+(b-c)))**0.5/4
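For reference, the expression in triangle_area is Heron's formula. It is written in the rearranged, numerically stable form usually attributed to Kahan, which expects the sides sorted so that a ≥ b ≥ c. The four factors must be joined with `*`. Here is a quick check on a right triangle with sides 5, 4 and 3, whose area is 4·3/2 = 6:

```python
def triangle_area(a, b, c):
    # Heron's formula in Kahan's stable arrangement (expects a >= b >= c)
    return ((a + (b + c)) * (c - (a - b)) * (c + (a - b)) * (a + (b - c))) ** 0.5 / 4

# Sides sorted in descending order: 5 >= 4 >= 3
print(triangle_area(5, 4, 3))  # 6.0
```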

Next, I want to automatically identify newly defined pure functions, such as:

def house_area(a,b,c):
    return rectangle_area(a,b) + triangle_area(a,b,c)

Obviously house_area is pure, since it only calls pure functions.

How can I discover all pure functions automatically (perhaps by using ast)?

Uri Goren

1 Answer


Assuming all operators are pure, you essentially only need to check every function call. This can indeed be done with the ast module.
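Here is a minimal illustration of that idea. It parses a source string directly instead of a live function, and uses ast.walk to list every plain-name call in house_area:

```python
import ast

# Source of the question's house_area, given as a string for illustration
source = """
def house_area(a, b, c):
    return rectangle_area(a, b) + triangle_area(a, b, c)
"""

tree = ast.parse(source)
# ast.walk yields every node in the tree; keep only calls to plain names
called = [
    node.func.id
    for node in ast.walk(tree)
    if isinstance(node, ast.Call) and isinstance(node.func, ast.Name)
]
print(called)  # ['rectangle_area', 'triangle_area']
```

Every name collected this way must then be checked for purity in turn, which is what the visitor below does.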

First I defined the pure decorator as:

def pure(f):
    f.pure = True
    return f

Marking the function with an attribute lets is_pure return early, and it also lets you "force" a function to be treated as pure. This is useful if you need a function like math.sin to count as pure. Since you can't add attributes to builtin functions, wrap it in a pure function instead:

import math

@pure
def sin(x):
    return math.sin(x)
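The snippet below shows why the wrapper is needed. Builtins such as math.sin refuse new attributes, while a plain Python function accepts them. The `pure` decorator from above is redefined here so the snippet runs on its own:

```python
import math

def pure(f):
    f.pure = True
    return f

try:
    # builtin functions have no __dict__, so this raises AttributeError
    math.sin.pure = True
except AttributeError as exc:
    print("cannot mark builtin:", exc)

@pure
def sin(x):
    return math.sin(x)

print(sin.pure)  # True
```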

Next, use the ast module to visit all the nodes, and for each Call node check whether the function being called is pure.

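import ast

class PureVisitor(ast.NodeVisitor):
    def __init__(self, visited):
        super().__init__()
        self.pure = True
        # Names of callees already checked (shared across recursive is_pure calls)
        self.visited = visited

    def visit_Name(self, node):
        return node.id

    def visit_Attribute(self, node):
        # Rebuild a dotted name such as "math.sin" from nested Attribute nodes
        name = [node.attr]
        child = node.value
        while isinstance(child, ast.Attribute):
            name.append(child.attr)
            child = child.value
        if not isinstance(child, ast.Name):
            # Not a plain dotted name (e.g. f().attr); still check calls inside it
            self.visit(child)
            return None
        name.append(child.id)
        return ".".join(reversed(name))

    def visit_Call(self, node):
        if not self.pure:
            return
        name = self.visit(node.func)
        if name is None:
            # The callee can't be resolved to a name (e.g. f()(x)): assume impure
            self.pure = False
            return
        if name not in self.visited:
            self.visited.append(name)
            try:
                callee = eval(name)
                if not is_pure(callee, self.visited):
                    self.pure = False
            except NameError:
                self.pure = False
        # Also check calls nested in the arguments, e.g. f(g(x))
        self.generic_visit(node)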
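The dotted-name reconstruction in visit_Attribute can be tried on its own. The helper below, dotted_name, is a hypothetical standalone version of the same idea. It returns None for call targets that aren't a plain chain of names:

```python
import ast
from typing import Optional

def dotted_name(node: ast.AST) -> Optional[str]:
    """Turn a call target such as os.path.join into "os.path.join",
    or return None when it is not a plain chain of names."""
    parts = []
    while isinstance(node, ast.Attribute):
        parts.append(node.attr)
        node = node.value
    if not isinstance(node, ast.Name):
        return None  # e.g. funcs[0](x) or f()(x)
    parts.append(node.id)
    return ".".join(reversed(parts))

for expr in ["os.path.join(a, b)", "math.sin(x)", "funcs[0](x)", "f()(x)"]:
    call = ast.parse(expr, mode="eval").body  # the ast.Call node
    print(expr, "->", dotted_name(call.func))
```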

Then check whether the function has the pure attribute. If it doesn't, get its source code and check whether all of its function calls can be classified as pure.

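import ast, inspect, textwrap

def is_pure(f, _visited=None):
    # Functions marked with @pure (or wrappers like sin above) short-circuit here
    try:
        return f.pure
    except AttributeError:
        pass

    # Builtins and classes have no __code__; the source may also be unavailable
    # (e.g. functions defined in the interactive interpreter)
    try:
        code = inspect.getsource(f.__code__)
    except (AttributeError, OSError, TypeError):
        return False

    # Methods and nested functions are indented, so dedent before parsing
    code = textwrap.dedent(code)
    try:
        node = ast.parse(code)
    except SyntaxError:
        return False

    if _visited is None:
        _visited = []

    visitor = PureVisitor(_visited)
    visitor.visit(node)
    return visitor.pure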

Note that print(is_pure(lambda x: math.sin(x))) doesn't work as intended, because inspect.getsource(f.__code__) returns whole source lines rather than just the lambda. The source returned for the lambda therefore also contains the print and is_pure calls. Since those aren't marked pure, the result is False, unless those functions are themselves wrapped and marked pure.
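One way around the getsource limitation is to skip the source entirely and inspect what the compiled function refers to. This sketch swaps in a different, coarser technique than the answer's. It uses inspect.getclosurevars, which reports the global and builtin names referenced by a function's own code object (nested functions are not inspected).

The check is deliberately conservative: any referenced global that is not a function marked pure makes the result False. That includes a module like math, a constant, or a class. The name is_pure_by_refs and the PURE_BUILTINS allow-list are assumptions for illustration:

```python
import inspect
import math


def pure(f):
    f.pure = True
    return f


# Hypothetical allow-list: builtins we choose to treat as pure
PURE_BUILTINS = {"abs", "len", "max", "min", "round", "sum"}


def is_pure_by_refs(f, _seen=None):
    """Rough, conservative purity check based on the global and builtin
    names that f's code refers to (no source code needed)."""
    if getattr(f, "pure", False):
        return True  # explicitly marked with @pure
    if not inspect.isfunction(f):
        return False  # modules, classes, builtins, constants: not trusted
    seen = set() if _seen is None else _seen
    if f in seen:
        return True  # already being checked further up (recursion)
    seen.add(f)
    refs = inspect.getclosurevars(f)
    if refs.nonlocals:
        return False  # captured closure variables: be conservative
    if any(name not in PURE_BUILTINS for name in refs.builtins):
        return False
    # Every referenced global must itself pass the check
    return all(is_pure_by_refs(obj, seen) for obj in refs.globals.values())


@pure
def rectangle_area(a, b):
    return a * b


def doubled_area(a, b):
    return 2 * rectangle_area(a, b)


print(is_pure_by_refs(doubled_area))           # True
print(is_pure_by_refs(lambda x: math.sin(x)))  # False: refers to the math module
print(is_pure_by_refs(lambda xs: print(xs)))   # False: print is not allow-listed
```

Because it works on the code object, this handles an inline lambda passed directly as an argument. The trade-off is that it only sees which names are referenced, not how they are used.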


To verify that it works, test it by doing:

print(is_pure(house_area))  # Prints: True

To list all the functions in the current module:

import sys, types

for k in dir(sys.modules[__name__]):
    v = globals()[k]
    if isinstance(v, types.FunctionType):
        print(k, is_pure(v))
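Here is an alternative to the dir()/globals() loop, assuming you can pass the module object itself. inspect.getmembers with inspect.isfunction lists the functions of any module, and comparing __module__ skips functions that were merely imported into it. The module built from a string here is purely for illustration:

```python
import inspect
import types

# Throwaway module for illustration; in practice you would pass an imported
# module, e.g. sys.modules[__name__]
shapes = types.ModuleType("shapes")
exec(
    "import math\n"
    "from inspect import getsource\n"
    "RATIO = 2\n"
    "def rectangle_area(a, b):\n"
    "    return a * b\n"
    "def circle_area(r):\n"
    "    return math.pi * r ** 2\n",
    shapes.__dict__,
)

# getmembers returns (name, value) pairs sorted by name; isfunction drops the
# math module and RATIO, and the __module__ test drops the imported getsource
own_functions = [
    name
    for name, obj in inspect.getmembers(shapes, inspect.isfunction)
    if obj.__module__ == shapes.__name__
]
print(own_functions)  # ['circle_area', 'rectangle_area']
```

Each function found this way could then be passed to is_pure. Note that functions created with exec like this have no retrievable source, so getsource would fail on them.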

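The visited list keeps track of the names of functions that have already been checked, which prevents infinite recursion. The code is only analysed, never executed. So without the visited list, checking a recursive function such as factorial below would keep re-visiting factorial until hitting the recursion limit, if it weren't marked @pure.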

@pure
def factorial(n):
    return 1 if n == 1 else n * factorial(n - 1)
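The role of the visited set can be seen in a self-contained sketch that works on source strings instead of live functions. SOURCES, MARKED_PURE, calls and is_pure_src are hypothetical names for illustration. Here factorial is deliberately not marked pure, so the recursion guard is actually exercised:

```python
import ast

# Hypothetical module contents: function name -> source code
SOURCES = {
    "factorial": (
        "def factorial(n):\n"
        "    return 1 if n == 1 else n * factorial(n - 1)\n"
    ),
    "log_factorial": (
        "def log_factorial(n):\n"
        "    print(n)\n"
        "    return factorial(n)\n"
    ),
}
# Names marked with @pure; deliberately empty so factorial's recursion is exercised
MARKED_PURE = set()


def calls(source):
    """Names of all plain function calls (f(...), not obj.f(...)) in source."""
    return {
        node.func.id
        for node in ast.walk(ast.parse(source))
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name)
    }


def is_pure_src(name, visited=None):
    visited = set() if visited is None else visited
    if name in MARKED_PURE:
        return True
    if name not in SOURCES:
        return False  # unknown callee such as print: assume impure
    visited.add(name)
    for callee in calls(SOURCES[name]):
        if callee in visited:
            # Already being checked; without this, factorial would recurse
            # until RecursionError
            continue
        if not is_pure_src(callee, visited):
            return False
    return True


print(is_pure_src("factorial"))      # True
print(is_pure_src("log_factorial"))  # False, because of print
```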

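Note that you might need to revise the following part of the code and choose another way to obtain a function from its name. eval(name) resolves the name in the scope of visit_Call (the module that defines PureVisitor), not in the module where the checked function was defined.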

try:
    callee = eval(name)
    if not is_pure(callee, self.visited):
        self.pure = False
except NameError:
    self.pure = False
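One possible replacement is to resolve dotted names by explicit lookups instead of eval. This sketch assumes PureVisitor is handed the checked function's __globals__, for example as an extra constructor argument, which the answer's class doesn't have. The helper name resolve is hypothetical. It looks names up in the namespace where the function was defined, falls back to builtins, and returns None rather than raising, so an unresolvable name can simply be treated as impure:

```python
import builtins
import math


def resolve(dotted, namespace):
    """Look up a dotted name such as "math.sin" in namespace (e.g. the
    checked function's __globals__), falling back to builtins.
    Returns None instead of raising when the name can't be found."""
    head, *rest = dotted.split(".")
    if head in namespace:
        obj = namespace[head]
    elif hasattr(builtins, head):
        obj = getattr(builtins, head)
    else:
        return None
    for attr in rest:
        obj = getattr(obj, attr, None)
        if obj is None:
            return None
    return obj


def circle_area(r):
    return math.pi * r ** 2


ns = circle_area.__globals__
print(resolve("math.sin", ns) is math.sin)  # True
print(resolve("len", ns) is len)            # True
print(resolve("nope.x", ns))                # None
```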
vallentin
  • Great solution, thank you! I'm puzzled as to why `is_pure(f)` is `True` where `f = lambda x: math.sin(x)` (Note that `f` is not marked as `@pure`) – Uri Goren Aug 09 '17 at 15:18
  • And on the other hand, `is_pure(lambda x:math.sin(x))` is `False` – Uri Goren Aug 09 '17 at 15:22
  • The first one is actually because it doesn't account for "math" then "sin" (I'll fix that now). The second problem is due to `inspect.getsource(f.__code__)`, which doesn't return `lambda x:math.sin(x)` but the whole line. – vallentin Aug 09 '17 at 15:37
  • @UriGoren I fixed the first problem, but as already stated the second problem isn't directly solvable while using `inspect.getsource` (unless you manually extract a substring). However it is probably easier just to avoid passing a lambda like that, and instead do the `f = lambda ...` way if needed. – vallentin Aug 09 '17 at 16:05