
I'm currently playing around with abstract syntax trees, using the ast and astor modules. The documentation taught me how to retrieve and pretty-print source code for various functions, and various examples on the web show how to modify parts of the code by replacing the contents of one line with another or changing all occurrences of + to *.

However, I would like to insert additional code in various places, specifically when a function calls another function. For instance, the following hypothetical function:

def some_function(param):
    if param == 0:
        return case_0(param)
    elif param < 0:
        return negative_case(param)
    return all_other_cases(param)

would yield (once we've used astor.to_source(modified_ast)):

def some_function(param):
    if param == 0:
        print("Hey, we're calling case_0")
        return case_0(param)
    elif param < 0:
        print("Hey, we're calling negative_case")
        return negative_case(param)
    print("Seems we're in the general case, calling all_other_cases")
    return all_other_cases(param)

Is this possible with abstract syntax trees? (note: I'm aware that decorating functions that are called would produce the same results when running the code, but this is not what I'm after; I need to actually output the modified code, and insert more complicated things than print statements).

Anthony Labarre
  • Do you want to make such a translation at compile-time, or during run-time with arbitrarily defined functions? – poke Jan 20 '17 at 13:25
  • Preferably during run-time with arbitrarily defined functions, but I'm not rejecting other options yet. – Anthony Labarre Jan 20 '17 at 13:55
  • Well, at run-time, I don’t think you can use `ast` at all since at the time the function definition ran, you are only left with the already compiled byte code. So you’re looking for modifying byte code at runtime then. Maybe [this related question](http://stackoverflow.com/questions/33348067/modifying-python-bytecode) helps. – poke Jan 20 '17 at 14:01
  • [This PyCon talk by Ryan Kelly](https://www.youtube.com/watch?v=ve7lLHtJ9l8) might also be interesting if you want to go at changing byte code. – poke Jan 20 '17 at 14:04
  • Do you need help specific to one of the AST tree-walking classes in the `ast` or `astor` modules, or do you just need general help about how to insert new nodes into an AST? The latter's dead easy (you only need `list.insert`). For the former, it would help a lot if you showed the tree walking code you're using (rather than only the code you want to modify). – Blckknght Jan 20 '17 at 15:01


It's not clear from your question if you're asking about how to insert nodes into an AST at a low level, or more specifically about how to do node insertions with a higher-level tool for walking the tree (e.g. a subclass of ast.NodeVisitor or astor.TreeWalk).

Inserting nodes at a low level is exceedingly easy. You just use list.insert on an appropriate list in the tree. For instance, here's some code that adds the last of the three print calls you want (the other two would be almost as easy; they'd just require more indexing). Most of the code is building the new AST node for the print call. The actual insertion is very short:

import ast
import astor

source = """
def some_function(param):
    if param == 0:
        return case_0(param)
    elif param < 0:
        return negative_case(param)
    return all_other_cases(param)
"""

tree = ast.parse(source) # parse an ast tree from the source code

# build a new tree of AST nodes to insert into the main tree
# (ast.Constant works on Python 3.8+; use ast.Str on older versions, it was removed in 3.12)
message = ast.Constant("Seems we're in the general case, calling all_other_cases")
print_func = ast.Name("print", ast.Load())
print_call = ast.Call(print_func, [message], []) # add two None args in Python<=3.4
print_statement = ast.Expr(print_call)

# tree.body[0] is the FunctionDef; its body is [If, Return], so index 1 is just before the return
tree.body[0].body.insert(1, print_statement) # doing the actual insert here!

# now, do whatever you want with the modified ast tree.
print(astor.to_source(tree))

The output will be:

def some_function(param):
    if param == 0:
        return case_0(param)
    elif param < 0:
        return negative_case(param)
    print("Seems we're in the general case, calling all_other_cases")
    return all_other_cases(param)

(Note that the arguments for ast.Call changed between Python 3.4 and 3.5+. If you're using an older version of Python, you may need to add two additional None arguments: ast.Call(print_func, [message], [], None, None))
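The other two print calls only need a bit more indexing into the tree. Here is a sketch of all three insertions, assuming Python 3.9+ so that the standard library's ast.unparse can stand in for astor.to_source; make_print is a hypothetical helper name used only for illustration:

```python
import ast

source = """
def some_function(param):
    if param == 0:
        return case_0(param)
    elif param < 0:
        return negative_case(param)
    return all_other_cases(param)
"""

def make_print(text):
    """Hypothetical helper: build the AST for the statement print(text)."""
    call = ast.Call(ast.Name("print", ast.Load()), [ast.Constant(text)], [])
    return ast.Expr(call)

tree = ast.parse(source)
func = tree.body[0]           # the FunctionDef
outer_if = func.body[0]       # the `if param == 0:` statement
elif_if = outer_if.orelse[0]  # an elif is stored as an If inside the outer If's orelse list

# Each insert goes into a different statement list, just before that list's return.
outer_if.body.insert(0, make_print("Hey, we're calling case_0"))
elif_if.body.insert(0, make_print("Hey, we're calling negative_case"))
func.body.insert(1, make_print("Seems we're in the general case, calling all_other_cases"))

ast.fix_missing_locations(tree)  # fill in lineno/col_offset so the tree could also be compiled
print(ast.unparse(tree))         # Python 3.9+; astor.to_source(tree) should print similar code
```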

If you're using a higher-level approach, things are a little trickier, since the code needs to figure out where to insert the new nodes, rather than relying on your own knowledge of the input to hard-code the positions.
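If you'd rather stay within the standard library, ast.NodeTransformer offers one way for a walker to decide where new statements go: a visit_ method may return a list of nodes, and the transformer splices that list into the enclosing statement list in place of the original node. A minimal sketch, assuming Python 3.9+ for ast.unparse; PrintBeforeReturn is a hypothetical class that announces every return statement rather than every call:

```python
import ast

class PrintBeforeReturn(ast.NodeTransformer):
    """Hypothetical example: insert a print() before every return statement."""

    def visit_Return(self, node):
        what = ast.unparse(node.value) if node.value is not None else "None"
        call = ast.Call(ast.Name("print", ast.Load()), [ast.Constant("Returning " + what)], [])
        # Returning a list makes NodeTransformer splice both statements into the
        # enclosing statement list in place of the original return. This only
        # works for nodes stored in a list field, which statements always are.
        return [ast.Expr(call), node]

source = """
def some_function(param):
    if param == 0:
        return case_0(param)
    elif param < 0:
        return negative_case(param)
    return all_other_cases(param)
"""

tree = ast.fix_missing_locations(PrintBeforeReturn().visit(ast.parse(source)))
print(ast.unparse(tree))
```

The If statements need no special method: the default generic_visit descends into their body and orelse lists, so the returns nested there are found too.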

Here's a quick and dirty implementation of a TreeWalk subclass that adds a print call as the statement before any statement that has a Call node under it. Note that Call nodes include calls to classes (to create instances), not only function calls. This code only handles the outermost of a nested set of calls, so if the code had foo(bar()), the inserted print would only mention foo:

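import ast
import astor

class PrintBeforeCall(astor.TreeWalk):
    def pre_body_name(self):
        body = self.cur_node  # the list stored in some node's `body` attribute
        print_func = ast.Name("print", ast.Load())
        new_body = []
        for child in body:
            self.__name = None
            self.walk(child)  # pre_Call below records the outermost call's name
            if self.__name is not None:
                message = ast.Constant("Calling {}".format(self.__name))  # ast.Str on Python < 3.8
                new_body.append(ast.Expr(ast.Call(print_func, [message], [])))
            new_body.append(child)
        # Building a new list keeps each print next to its own statement even when
        # several statements in one body make calls; then update the list in place.
        body[:] = new_body
        self.__name = None
        return True  # children were walked manually above, so TreeWalk skips them

    def pre_Call(self):
        func = self.cur_node.func
        # plain calls like foo() have .id, method calls like obj.foo() have .attr
        self.__name = getattr(func, "id", None) or getattr(func, "attr", None)
        return True  # don't descend, so only the outermost of nested calls is recorded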

You'd call it like this:

source = """
def some_function(param):
    if param == 0:
        return case_0(param)
    elif param < 0:
        return negative_case(param)
    return all_other_cases(param)
"""

tree = ast.parse(source)

walker = PrintBeforeCall()   # create an instance of the TreeWalk subclass
walker.walk(tree)   # modify the tree in place

print(astor.to_source(tree))

The output this time is:

def some_function(param):
    if param == 0:
        print('Calling case_0')
        return case_0(param)
    elif param < 0:
        print('Calling negative_case')
        return negative_case(param)
    print('Calling all_other_cases')
    return all_other_cases(param)

That's not quite the exact messages you wanted, but it's close. The walker can't describe the cases being handled in detail, since it only looks at the names of the functions being called, not at the conditions that led there. If you have a very well-defined set of things to look for, you could perhaps change it to look at the ast.If nodes, but I suspect that would be a lot more challenging.
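As a rough sketch of the ast.If idea for the simple case in the question, here is a hypothetical DescribeBranches transformer (standard library ast.NodeTransformer, Python 3.9+ assumed for ast.unparse). It prefixes each if/elif body with a message quoting that branch's condition whenever the body's first statement makes a call. It does not cover the fall-through case after the if, and called_name and make_print are illustrative helpers, not part of any library:

```python
import ast

def make_print(text):
    """Illustrative helper: build the AST for the statement print(text)."""
    return ast.Expr(ast.Call(ast.Name("print", ast.Load()), [ast.Constant(text)], []))

def called_name(stmt):
    """Illustrative helper: name of a call found inside stmt, or None."""
    for sub in ast.walk(stmt):  # breadth-first, so shallower calls are seen first
        if isinstance(sub, ast.Call):
            if isinstance(sub.func, ast.Name):
                return sub.func.id
            if isinstance(sub.func, ast.Attribute):
                return sub.func.attr
            return None
    return None

class DescribeBranches(ast.NodeTransformer):
    """Hypothetical sketch: say which condition led to each branch's call."""

    def visit_If(self, node):
        name = called_name(node.body[0])  # inspect before nested ifs get their own prints
        self.generic_visit(node)          # then transform nested ifs and elifs
        if name is not None:
            cond = ast.unparse(node.test)  # source text of the condition, e.g. param == 0
            node.body.insert(0, make_print(f"{cond} holds, calling {name}"))
        return node

source = """
def some_function(param):
    if param == 0:
        return case_0(param)
    elif param < 0:
        return negative_case(param)
    return all_other_cases(param)
"""

tree = ast.fix_missing_locations(DescribeBranches().visit(ast.parse(source)))
print(ast.unparse(tree))
```

Because an elif is stored as an If inside the outer If's orelse list, generic_visit reaches it without extra code, so both branches get a message.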

Blckknght