It's not clear from your question if you're asking about how to insert nodes into an AST tree at a low level, or more specifically about how to do node insertions with a higher level tool to walk the AST tree (e.g. a subclass of ast.NodeVisitor or astor.TreeWalk).
Inserting nodes at a low level is exceedingly easy: you just call list.insert on the appropriate statement list in the tree (such as the body of a function or of an if block). For instance, here's some code that adds the last of the three print calls you want (the other two would be almost as easy; they'd just require more indexing). Most of the code builds the new AST node for the print call. The actual insertion is very short:
import ast
import astor  # third-party package: pip install astor

source = """
def some_function(param):
    if param == 0:
        return case_0(param)
    elif param < 0:
        return negative_case(param)
    return all_other_cases(param)
"""
tree = ast.parse(source) # parse an ast tree from the source code
# build a new tree of AST nodes to insert into the main tree
# (ast.Constant replaces ast.Str, which is deprecated since Python 3.8; use ast.Str on older versions)
message = ast.Constant("Seems we're in the general case, calling all_other_cases")
print_func = ast.Name("print", ast.Load())
print_call = ast.Call(print_func, [message], []) # add two None args in Python<=3.4
print_statement = ast.Expr(print_call)
# tree.body[0] is the FunctionDef; its body is [If, Return], so index 1 is just before the return
tree.body[0].body.insert(1, print_statement) # doing the actual insert here!
# now, do whatever you want with the modified ast tree.
print(astor.to_source(tree))
The output will be:
def some_function(param):
    if param == 0:
        return case_0(param)
    elif param < 0:
        return negative_case(param)
    print("Seems we're in the general case, calling all_other_cases")
    return all_other_cases(param)
(Note that the arguments for ast.Call changed in Python 3.5, which removed its starargs and kwargs fields. If you're using Python 3.4 or older, you need to add two additional None arguments: ast.Call(print_func, [message], [], None, None).)
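If you'd rather not depend on astor, the same low-level insertion can be sketched with only the standard library on Python 3.9+, where ast.unparse plays the role of astor.to_source. This version also shows the extra indexing the other two print calls would need: the first branch's statements are in tree.body[0].body[0].body, and the elif is stored as an If node inside the outer If's orelse list. The messages for those two branches are placeholders I made up, make_print is just a small helper for this example, and the stub functions exist only so the modified tree can be compiled and run. Newly built nodes have no line numbers, so ast.fix_missing_locations is needed before compile() will accept the tree.

```python
import ast

source = """
def some_function(param):
    if param == 0:
        return case_0(param)
    elif param < 0:
        return negative_case(param)
    return all_other_cases(param)
"""

def make_print(text):
    """Hypothetical helper: build the AST for the statement print(text)."""
    call = ast.Call(ast.Name("print", ast.Load()), [ast.Constant(text)], [])
    return ast.Expr(call)

tree = ast.parse(source)
func_body = tree.body[0].body   # [If, Return]
outer_if = func_body[0]         # the `if param == 0:` statement
inner_if = outer_if.orelse[0]   # `elif` is an If nested in the outer If's orelse

# Each insert touches a different list, so the order of these calls doesn't matter.
outer_if.body.insert(0, make_print("param is zero, calling case_0"))            # placeholder text
inner_if.body.insert(0, make_print("param is negative, calling negative_case"))  # placeholder text
func_body.insert(1, make_print("Seems we're in the general case, calling all_other_cases"))

ast.fix_missing_locations(tree)      # give the new nodes lineno/col_offset for compile()
modified_source = ast.unparse(tree)  # stdlib alternative to astor.to_source (Python 3.9+)
print(modified_source)

# Hypothetical stubs so the modified function can actually run
namespace = {
    "case_0": lambda p: "case_0 result",
    "negative_case": lambda p: "negative result",
    "all_other_cases": lambda p: "general result",
}
exec(compile(tree, "<modified>", "exec"), namespace)
result = namespace["some_function"](5)  # prints the general-case message first
```

Calling some_function with 0 or a negative number would print the corresponding placeholder message before calling the matching stub.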
If you're using a higher level approach, things are a little bit trickier, since the code needs to figure out for itself where to insert the new nodes, rather than relying on your own knowledge of the input to hard-code the indexes.
Here's a quick and dirty implementation of a TreeWalk subclass that adds a print call as the statement before any statement that has a Call node under it. Note that Call nodes include calls to classes (to create instances), not only function calls. This code only handles the outermost of a nested set of calls, so if the code had foo(bar()), the inserted print will only mention foo:
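import ast
import astor

class PrintBeforeCall(astor.TreeWalk):
    def pre_body_name(self):
        body = self.cur_node
        print_func = ast.Name("print", ast.Load())
        inserted = 0  # each insert shifts the later statements down by one
        for i, child in enumerate(body[:]):
            self.__name = None
            self.walk(child)
            if self.__name is not None:
                message = ast.Constant("Calling {}".format(self.__name))  # ast.Str before 3.8
                print_statement = ast.Expr(ast.Call(print_func, [message], []))
                body.insert(i + inserted, print_statement)
                inserted += 1
        self.__name = None
        return True  # children were already walked above, so TreeWalk can skip them

    def pre_Call(self):
        func = self.cur_node.func
        # foo() has a Name with .id; obj.foo() has an Attribute with .attr
        self.__name = getattr(func, "id", None) or getattr(func, "attr", None)
        return True  # don't descend into nested calls, so only the outermost is named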
You'd call it like this:
source = """
def some_function(param):
    if param == 0:
        return case_0(param)
    elif param < 0:
        return negative_case(param)
    return all_other_cases(param)
"""
tree = ast.parse(source)
walker = PrintBeforeCall() # create an instance of the TreeWalk subclass
walker.walk(tree) # modify the tree in place
print(astor.to_source(tree))
The output this time is:
def some_function(param):
    if param == 0:
        print('Calling case_0')
        return case_0(param)
    elif param < 0:
        print('Calling negative_case')
        return negative_case(param)
    print('Calling all_other_cases')
    return all_other_cases(param)
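For comparison, here's a rough equivalent that needs only the standard library (Python 3.9+ for ast.unparse), using ast.NodeTransformer instead of astor.TreeWalk. NodeTransformer splices a list returned from a visit method into the parent's statement list, so returning [print_statement, node] does the insertion. The helpers find_outer_call and call_name are just names for this sketch, and unlike the TreeWalk version it names the first outermost call in a statement's own expressions, ignoring calls inside nested statement bodies (those statements get their own print). For the example input it should produce the same print calls as shown above.

```python
import ast

def find_outer_call(node):
    """Return the first Call among node's expressions, without looking inside
    nested statements or inside the arguments of a call already found."""
    for child in ast.iter_child_nodes(node):
        if isinstance(child, ast.stmt):
            continue  # nested statements get their own print
        if isinstance(child, ast.Call):
            return child
        found = find_outer_call(child)
        if found is not None:
            return found
    return None

def call_name(call):
    func = call.func
    if isinstance(func, ast.Name):
        return func.id    # foo()
    if isinstance(func, ast.Attribute):
        return func.attr  # obj.foo()
    return None           # e.g. funcs[0](): nothing simple to name

class PrintBeforeCall(ast.NodeTransformer):
    def generic_visit(self, node):
        super().generic_visit(node)  # rewrite nested statement lists first
        if isinstance(node, ast.stmt):
            call = find_outer_call(node)
            name = call_name(call) if call is not None else None
            if name is not None:
                message = ast.Constant("Calling {}".format(name))
                printer = ast.Expr(ast.Call(ast.Name("print", ast.Load()), [message], []))
                return [printer, node]  # spliced into the parent's statement list
        return node

source = """
def some_function(param):
    if param == 0:
        return case_0(param)
    elif param < 0:
        return negative_case(param)
    return all_other_cases(param)
"""
tree = PrintBeforeCall().visit(ast.parse(source))
ast.fix_missing_locations(tree)  # only needed if you later compile() the tree
output = ast.unparse(tree)
print(output)

# Nested calls: only the outermost one is named
nested = ast.unparse(PrintBeforeCall().visit(ast.parse("x = foo(bar())")))
print(nested)
```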
That's not quite the exact messages you wanted, but it's close. The walker can't describe the cases being handled in detail, since it only looks at the names of the functions being called, not at the conditions that led there. If you have a very well defined set of things to look for, you could perhaps change it to look at the ast.If nodes, but I suspect that would be a lot more challenging.
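As a very rough illustration of the ast.If idea, the hypothetical DescribeBranches transformer below (standard library only, Python 3.9+) inserts a print at the top of each if/elif branch, using ast.unparse to turn the branch's test expression back into source text. Note what it leaves out: the final fall-through case would need a message built from the negation of every earlier condition, which is where this starts getting harder.

```python
import ast

class DescribeBranches(ast.NodeTransformer):
    """Hypothetical sketch: announce which if/elif condition matched."""

    def visit_If(self, node):
        self.generic_visit(node)            # handle nested ifs, including elif chains in orelse
        condition = ast.unparse(node.test)  # e.g. "param == 0"
        message = ast.Constant("{} was true".format(condition))
        printer = ast.Expr(ast.Call(ast.Name("print", ast.Load()), [message], []))
        node.body.insert(0, printer)        # first statement of this branch
        return node

source = """
def some_function(param):
    if param == 0:
        return case_0(param)
    elif param < 0:
        return negative_case(param)
    return all_other_cases(param)
"""
tree = DescribeBranches().visit(ast.parse(source))
ast.fix_missing_locations(tree)
output = ast.unparse(tree)
print(output)
```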