I'm trying to learn about the `ast` module and `NodeTransformer` through this hypothetical case:
*Replace the keywords `elif`/`else` while ensuring that the program output remains the same.*
My approach is to introduce a temporary variable for each `if` block and test it to determine whether the `elif`/`else` parts of that block should be executed.
For example, given:

```python
code = '''
if a:
    print(True)
else:
    print(False)
'''
```
the transformed code would be:

```python
_boolIf0 = False
if a:
    _boolIf0 = True
    print(True)
if not _boolIf0:
    print(False)
```
I created a class `ParentLinks` that sets a `parent` attribute on every node in the tree. I then created a subclass, `ReplaceElse`, to update the parent (or any other ancestor) of a node, an `If` node in this example. As a first step, I simply want to insert this temporary variable at the right place in the AST.
```python
import ast


class ParentLinks(ast.NodeTransformer):
    parent = None

    def visit(self, node):
        node.parent = self.parent
        self.parent = node
        node = super().visit(node)
        if isinstance(node, ast.AST):
            self.parent = node.parent
        return node


class ReplaceElse(ParentLinks):
    def visit_If(self, node: ast.If):
        super().generic_visit(node)  # Visit child nodes
        # Create a temp variable for the current If block in the form of an Assign node
        temp_var_id = f'_boolIf{node.col_offset}'
        assign_if_node = ast.Assign(
            targets=[ast.Name(id=temp_var_id, ctx=ast.Store())],
            value=ast.Constant(value=False)
        )
        p = node.parent  # Get the parent of the current If node
        if isinstance(p, ast.Module):  # Parent is the module
            pos = p.body.index(node)  # Index of the current node in the module (0 in our example)
            # Insert the assign node at that index
            p.body = [assign_if_node] + p.body
            # p.body.insert(pos, assign_if_node)  # Triggers infinite recursion
            # ast.increment_lineno(node, n=1)
            # ast.fix_missing_locations(node)
        else:
            ...  # Extra logic for other types of parents
        return node


tree = ast.parse(code)
print(ast.dump(tree, indent=' '))
new_tree = ReplaceElse().visit(tree)
new_tree = ast.fix_missing_locations(new_tree)
print(ast.dump(new_tree, indent=' '))
print(ast.unparse(new_tree))  # Print transformed code
```
I tried two ways of inserting a node into the `Module` body:

- `p.body.insert(pos, assign_if_node)` triggers infinite recursion. Using `ast.increment_lineno` and `ast.fix_missing_locations(node)` did not help.
- `p.body = [assign_if_node] + p.body` seems to work.

So, what is the correct way of updating a node other than the one currently being visited by `NodeTransformer`? Can such nodes be updated this way at all, and if not, how should this be done?
Edit: I'm using Python 3.11
Your first approach, using `p.body.insert`, did exactly what you told it to do. However, the default `NodeTransformer` visitor does not track whether visiting the current node inserted a new node before it. `NodeTransformer.generic_visit` walks each list-valued field (obtained through `ast.iter_fields`) with a plain `for` loop, and that loop does not notice that an element was inserted before the current position. The `If` node is shifted one slot to the right, which is exactly where the loop looks next, so the same `If` is visited again, inserts another assignment, and so on. The result is an infinite loop (not infinite recursion).
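The loop can be reproduced in a bounded way. The sketch below is a stand-in for the question's code rather than a copy of it: `InsertBeforeIf` and its `max_visits` cap are hypothetical names, and it remembers the module in `visit_Module` instead of using `ParentLinks`. The cap exists only so the demonstration terminates; without it, the same `If` node would be visited forever.

```python
import ast


class InsertBeforeIf(ast.NodeTransformer):
    """Mimics `p.body.insert(pos, ...)` from the question, but stops after max_visits."""

    def __init__(self, max_visits=5):
        self.if_visits = 0
        self.max_visits = max_visits  # hypothetical safety cap, not in the original code
        self.module = None

    def visit_Module(self, node):
        self.module = node  # remember the parent of the top-level statements
        return self.generic_visit(node)

    def visit_If(self, node):
        self.if_visits += 1
        if self.if_visits >= self.max_visits:
            raise RuntimeError(f"If node visited {self.if_visits} times")
        pos = self.module.body.index(node)
        flag = ast.Assign(
            targets=[ast.Name(id='_boolIf0', ctx=ast.Store())],
            value=ast.Constant(value=False),
        )
        # Inserting before the current node shifts it one slot to the right,
        # which is exactly where generic_visit's for loop looks next.
        self.module.body.insert(pos, flag)
        return node


tree = ast.parse("if a:\n    print(True)\nelse:\n    print(False)\n")
transformer = InsertBeforeIf()
try:
    transformer.visit(tree)
except RuntimeError as exc:
    print(exc)  # If node visited 5 times
print(len(tree.body))  # 5: four inserted assignments followed by the original If
```

Note that none of the inserted `Assign` nodes is ever visited: the loop only ever lands on the `If`.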
A naive workaround under this approach would be to mark visited nodes and skip them on later visits, but this can have unwanted side effects or cost additional memory.
Your second approach, `p.body = [assign_if_node] + p.body`, does not trigger the infinite loop simply because it replaces `p.body` with a new list that includes `assign_if_node`, while the iterator in the visitor still references the original list. No new elements are therefore presented to the iterator that would extend what it was looking at. (This approach also prepends every assignment to the start of the body instead of placing it immediately before its `if` block, unlike the `body.insert` approach.)
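The reassignment approach also has a less visible cost. `generic_visit` writes its results back with `old_value[:] = new_values`, i.e. into the list object it was iterating, and after `p.body = [...]` the module no longer references that list. The hypothetical `PrependAndDropPass` below illustrates this; its `visit_Pass` deletion is there only to make the effect visible, and it remembers the module in `visit_Module` instead of using `ParentLinks`. A removal requested for a later sibling is silently lost.

```python
import ast


class PrependAndDropPass(ast.NodeTransformer):
    """Uses the question's `p.body = [assign] + p.body` trick, plus a deletion."""

    def visit_Module(self, node):
        self.module = node
        return self.generic_visit(node)

    def visit_If(self, node):
        self.generic_visit(node)
        flag = ast.Assign(
            targets=[ast.Name(id='_boolIf0', ctx=ast.Store())],
            value=ast.Constant(value=False),
        )
        # Rebinds module.body to a brand-new list; generic_visit keeps iterating the old one.
        self.module.body = [flag] + self.module.body
        return node

    def visit_Pass(self, node):
        return None  # returning None asks NodeTransformer to remove the node


tree = ast.parse("if a:\n    print(True)\npass\n")
tree = ast.fix_missing_locations(PrependAndDropPass().visit(tree))
print(ast.unparse(tree))
# The top-level `pass` is still there: its removal was written into the orphaned list.
```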
The correct method is to check how the visitor is meant to be used and use it as documented. The `ast.NodeTransformer` documentation states that it will "use the return value of the visitor methods to replace or remove the old node", and that for nodes that "were part of a collection of statements", "the visitor may also return a list of nodes". Taking that into account, this is what should be done:
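```python
import ast


class ReplaceElse(ast.NodeTransformer):
    def visit_If(self, node: ast.If):
        super().generic_visit(node)  # visit child nodes first
        temp_var_id = f'_boolIf{node.col_offset}'
        assign_if_node = ast.Assign(
            targets=[ast.Name(id=temp_var_id, ctx=ast.Store())],
            value=ast.Constant(value=False),
        )
        # Returning a list replaces this statement with both nodes, in order,
        # inside whatever statement list (Module.body, If.body, ...) held it.
        return [
            assign_if_node,
            node,
        ]
```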
Notice how this example does not manipulate the parent manually; it simply returns a list with the new `ast.Assign` node placed before the visited node. That is what the original code in the question should have done (though it does not yet produce the desired output, since the code for the rest of the transformation was not provided in the question).
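For completeness, here is one way the return-a-list approach could be extended to the full transformation the question describes. It is a sketch under several assumptions: the `_boolIf` prefix never clashes with names in the input program, a counter rather than `col_offset` is used so that every flag name is distinct (two `if` statements at the same indentation share a `col_offset`), and `elif` needs no special case because the parser represents it as an `If` nested in the `orelse` of the previous one. `run` is a hypothetical helper used only to compare printed output.

```python
import ast
import contextlib
import io


class ReplaceElse(ast.NodeTransformer):
    """Rewrites if/elif/else chains into flag-guarded plain `if` statements."""

    def __init__(self):
        self.counter = 0  # assumption: a counter keeps every flag name distinct

    def _flag_assign(self, name, value):
        return ast.Assign(
            targets=[ast.Name(id=name, ctx=ast.Store())],
            value=ast.Constant(value=value),
        )

    def visit_If(self, node: ast.If):
        # Transform nested statements first; an `elif` is an If inside node.orelse.
        self.generic_visit(node)
        if not node.orelse:
            return node  # plain `if` without else: nothing to replace
        flag = f'_boolIf{self.counter}'
        self.counter += 1
        # Record that the if-branch was taken, as the first statement of its body.
        node.body.insert(0, self._flag_assign(flag, True))
        guarded_else = ast.If(
            test=ast.UnaryOp(op=ast.Not(), operand=ast.Name(id=flag, ctx=ast.Load())),
            body=node.orelse,
            orelse=[],
        )
        node.orelse = []
        # A statement visitor may return a list; it replaces the original If in its parent.
        return [self._flag_assign(flag, False), node, guarded_else]


def run(tree_or_source, **env):
    """Hypothetical helper: execute code and return everything it printed."""
    if isinstance(tree_or_source, ast.AST):
        tree_or_source = compile(tree_or_source, '<transformed>', 'exec')
    buf = io.StringIO()
    with contextlib.redirect_stdout(buf):
        exec(tree_or_source, dict(env))
    return buf.getvalue()


source = '''
if a:
    print(True)
elif b:
    print('b')
else:
    print(False)
'''

new_tree = ast.fix_missing_locations(ReplaceElse().visit(ast.parse(source)))
print(ast.unparse(new_tree))

# The transformed program should print the same thing for every input combination.
for a in (True, False):
    for b in (True, False):
        assert run(source, a=a, b=b) == run(new_tree, a=a, b=b)
```

Because each initialising `_boolIfN = False` is placed immediately before its `if`, it is re-executed whenever a loop reaches that statement again, so a flag set in one iteration cannot leak into the next.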