Given part of an AST for some code, I need to remove particular default assignments from function definitions. Specifically, I need to remove every variable contained in a list vars_to_remove from function definitions where that variable is used as the default value of a parameter.
For example, take vars_to_remove = ['sum1'] and the function def do_smth(sum = sum1). Assume sum1 was defined globally earlier and has been added to that list.
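To see where such a default lives, it helps to look at the tree the standard-library ast module builds for this example. The short sketch below shows that the parameter and its default are stored separately: the parameter is an ast.arg in node.args.args, while the default is an ast.Name in node.args.defaults.

```python
import ast

tree = ast.parse("def do_smth(sum=sum1):\n    pass\n")
func = tree.body[0]                # the FunctionDef node

param = func.args.args[0]          # ast.arg for the parameter 'sum'
default = func.args.defaults[0]    # ast.Name for the default 'sum1'

print(type(func).__name__, param.arg, type(default).__name__, default.id)
# FunctionDef sum Name sum1
```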
I can't find a way to remove what I want without either removing the whole FunctionDef node or losing the ability to uniquely identify the node I need to remove. See my attempts below.
I tried to solve the problem in two ways:
Parent way:
I override visit_FunctionDef(self, node), which gives me access to the parameter I need to remove. However, if I return None, the whole FunctionDef node is removed from the tree, because that is the node that was passed in. I only need to remove the nodes corresponding to node.args.defaults[0] ... node.args.defaults[n].
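This behaviour follows from how ast.NodeTransformer works: the value a visit_* method returns replaces the visited node, and returning None deletes it. A minimal reproduction of the parent-way problem (DropFunction is a hypothetical name used only for illustration; ast.unparse requires Python 3.9+):

```python
import ast

source = "def do_smth(sum=sum1):\n    pass\n\nx = sum1\n"


class DropFunction(ast.NodeTransformer):
    def visit_FunctionDef(self, node):
        # Returning None removes the visited node, i.e. the whole
        # function definition, not just its default value.
        return None


tree = DropFunction().visit(ast.parse(source))
print(ast.unparse(tree))
# x = sum1
```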
Child way:
I override visit_Name(self, node). There I can return None, which removes the node. However, this applies to every Name node in the whole program (from which the AST is derived), not only to the ones inside function definitions. I am removing id='sum1', but I do not necessarily want to remove every other Name node in the program whose id is sum1!
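A quick way to see why visit_Name is too broad is to record every Name node it is called for. In this sketch (NameCollector is a hypothetical helper), both the default in the signature and an unrelated use of sum1 later in the module are reported, identified by line number:

```python
import ast

source = (
    "def do_smth(sum=sum1):\n"
    "    pass\n"
    "\n"
    "total = sum1 + 1\n"
)


class NameCollector(ast.NodeVisitor):
    def __init__(self, target):
        self.target = target
        self.lines = []

    def visit_Name(self, node):
        # Called for every Name node in the tree, whatever its context.
        if node.id == self.target:
            self.lines.append(node.lineno)
        self.generic_visit(node)


collector = NameCollector("sum1")
collector.visit(ast.parse(source))
print(collector.lines)  # [1, 4]
```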
I think I am missing an easy solution.
You can do this by modifying the function's arguments inside visit_FunctionDef and returning the modified node instead of None:
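    import ast

    vars_to_remove = ['sum1']

    class Transformer(ast.NodeTransformer):
        def visit_FunctionDef(self, node):
            # Process nested functions/classes in the body first.
            self.generic_visit(node)
            defaults = node.args.defaults
            # Defaults belong to the *last* len(defaults) positional args, so
            # compute the split index explicitly. Slicing with -len(defaults)
            # would drop every argument when there are no defaults (-0 == 0).
            # Assumes no positional-only ("/") parameters, whose defaults
            # share the same list.
            split = len(node.args.args) - len(defaults)
            args_without_defaults = node.args.args[:split]
            args_with_defaults = node.args.args[split:]
            # Filter args and defaults together so they stay in sync. Only
            # Name nodes have an .id; other defaults (e.g. x=1) are kept.
            retain = [
                (a, d)
                for (a, d) in zip(args_with_defaults, defaults)
                if not (isinstance(d, ast.Name) and d.id in vars_to_remove)
            ]
            node.args.args = args_without_defaults + [a for (a, d) in retain]
            node.args.defaults = [d for (a, d) in retain]
            return node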
Given this source:
    def func1(sum=sum1):
        pass
    def func2(a=spam, b=sum1):
        pass
    def func3(a=spam):
        pass
    def func4(a=spam, b=sum1, c=eggs):
        pass
this result is output:
    def func1():
        pass
    def func2(a=spam):
        pass
    def func3(a=spam):
        pass
    def func4(a=spam, c=eggs):
        pass
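An end-to-end sketch of producing such output: parse the source, apply a transformer, and print it with ast.unparse (Python 3.9+). The transformer here is a self-contained variant that also handles keyword-only defaults (node.args.kwonlyargs paired with node.args.kw_defaults, which holds None for keyword-only args without a default), functions without defaults, and async def. The helper _is_removable and the extra cases func5 and func6 are hypothetical, added for illustration; positional-only ("/") parameters are assumed absent.

```python
import ast

vars_to_remove = ['sum1']  # the list from the question


def _is_removable(default):
    # Hypothetical helper. Only bare names can match; constants and the
    # None placeholders in kw_defaults are kept.
    return isinstance(default, ast.Name) and default.id in vars_to_remove


class Transformer(ast.NodeTransformer):
    def visit_FunctionDef(self, node):
        self.generic_visit(node)  # also handle nested functions
        a = node.args
        # Positional defaults pair with the last len(a.defaults) args.
        # Assumes no positional-only ("/") parameters.
        split = len(a.args) - len(a.defaults)
        kept = [(arg, d) for arg, d in zip(a.args[split:], a.defaults)
                if not _is_removable(d)]
        a.args = a.args[:split] + [arg for arg, _ in kept]
        a.defaults = [d for _, d in kept]
        # Keyword-only args pair one-to-one with kw_defaults.
        kw = [(arg, d) for arg, d in zip(a.kwonlyargs, a.kw_defaults)
              if not _is_removable(d)]
        a.kwonlyargs = [arg for arg, _ in kw]
        a.kw_defaults = [d for _, d in kw]
        return node

    # Treat "async def" the same way.
    visit_AsyncFunctionDef = visit_FunctionDef


source = """
def func1(sum=sum1):
    pass
def func4(a=spam, b=sum1, c=eggs):
    pass
def func5(x, y=1, *, z=sum1):
    pass
def func6(p, q):
    pass
"""

tree = Transformer().visit(ast.parse(source))
print(ast.unparse(tree))
```

For this input, func1 and func4 come out as shown above, func5 loses only its keyword-only parameter z (ast.unparse then omits the bare *), and func6, which has no defaults, is left unchanged.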