Consider the following Python source code:
```python
import os

def package_data(pkg, roots):
    data = []
    for root in roots:
        for dirname, _, files in os.walk(os.path.join(pkg, root)):
            for fname in files:
                data.append(os.path.relpath(os.path.join(dirname, fname), pkg))
    return {pkg: data}
```
From this source code, I want to extract all the function and API calls. I found a similar question and solution. I ran the solution given in that answer, and it generates the output `[os.walk, data.append]`. But I am looking for the following output: `[os.walk, os.path.join, data.append, os.path.relpath, os.path.join]`.
From analyzing the solution code below, my understanding is that it visits only the part of each call before the first opening parenthesis (the function expression) and drops everything else, including the arguments.
```python
import ast

class CallCollector(ast.NodeVisitor):
    def __init__(self):
        self.calls = []
        self.current = None

    def visit_Call(self, node):
        # new call, trace the function expression
        self.current = ''
        self.visit(node.func)
        self.calls.append(self.current)
        self.current = None

    def generic_visit(self, node):
        if self.current is not None:
            print("warning: {} node in function expression not supported".format(
                node.__class__.__name__))
        super(CallCollector, self).generic_visit(node)

    # record the func expression
    def visit_Name(self, node):
        if self.current is None:
            return
        self.current += node.id

    def visit_Attribute(self, node):
        if self.current is None:
            # Not inside a call's function expression: just keep walking.
            self.generic_visit(node)
            return
        self.visit(node.value)
        self.current += '.' + node.attr

# yoursource is a string holding the code to analyse
tree = ast.parse(yoursource)
cc = CallCollector()
cc.visit(tree)
print(cc.calls)
```
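To see why the code above misses the nested calls, it helps to look at how `ast` represents a call: the callee and the arguments are stored in separate fields of an `ast.Call` node, and `visit_Call` above only visits `node.func`. A minimal check (the field tuple shown assumes Python 3.5 or later):

```python
import ast

# mode="eval" parses a single expression; tree.body is the top-level Call node.
tree = ast.parse("os.walk(os.path.join(pkg, root))", mode="eval")
call = tree.body

# The callee and the arguments live in separate fields.
print(call._fields)               # ('func', 'args', 'keywords')
print(type(call.func).__name__)   # Attribute (the os.walk part)

# The nested os.path.join(...) call sits in args, which visit_Call never visits.
print(type(call.args[0]).__name__)  # Call
```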
Can anyone help me modify this code so that it also traverses the API calls inside the parentheses?
N.B.: This could be done with regular expressions in Python, but that requires a lot of manual work to identify the appropriate API calls. So I am looking for a solution based on the abstract syntax tree (AST).
I'm not sure this is the best or simplest solution, but at least it works as intended for your case:
```python
import ast

class CallCollector(ast.NodeVisitor):
    def __init__(self):
        self.calls = []
        self._current = []
        self._in_call = False

    def visit_Call(self, node):
        # A new call starts: forget any partially collected name.
        self._current = []
        self._in_call = True
        self.generic_visit(node)

    def visit_Attribute(self, node):
        if self._in_call:
            self._current.append(node.attr)
        self.generic_visit(node)

    def visit_Name(self, node):
        if self._in_call:
            self._current.append(node.id)
            # Parts were collected outermost-first, so reverse before joining.
            self.calls.append('.'.join(self._current[::-1]))
            # Reset the state
            self._current = []
            self._in_call = False
        self.generic_visit(node)
```
For your example, this gives:
['os.walk', 'os.path.join', 'data.append', 'os.path.relpath', 'os.path.join']
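The catch is that you have to call `generic_visit` in every `visit_*` method to make sure the whole tree is walked; the original `visit_Call` only visits `node.func`, so the arguments are never reached. I also used a list for `_current` so that the collected parts can be joined (in reverse order) afterwards.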
One case where this approach doesn't work is chained calls, for example `d.setdefault(10, []).append(10)`: visiting the inner call resets the state, so `append` is lost and only `d.setdefault` is recorded.
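One possible way around the chained-call problem is to build the dotted name of `node.func` with a small recursive helper instead of collecting state across visits, and let `generic_visit` find nested calls on its own. This is an alternative sketch, not the answer's method; the helper name `dotted_name` and the convention of marking an inner call with `"()"` are hypothetical choices made for illustration.

```python
import ast


def dotted_name(node):
    """Build a dotted name for a call's func expression, or None if unsupported."""
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.Attribute):
        base = dotted_name(node.value)
        return None if base is None else base + "." + node.attr
    if isinstance(node, ast.Call):
        # Chained call such as d.setdefault(10, []).append: mark the inner call with "()".
        base = dotted_name(node.func)
        return None if base is None else base + "()"
    return None  # e.g. subscripts, lambdas, string literals


class CallCollector(ast.NodeVisitor):
    def __init__(self):
        self.calls = []

    def visit_Call(self, node):
        name = dotted_name(node.func)
        if name is not None:
            self.calls.append(name)
        # Keep walking so calls nested in func and args are collected too.
        self.generic_visit(node)


source = '''
def package_data(pkg, roots):
    data = []
    for root in roots:
        for dirname, _, files in os.walk(os.path.join(pkg, root)):
            for fname in files:
                data.append(os.path.relpath(os.path.join(dirname, fname), pkg))
    return {pkg: data}
'''
cc = CallCollector()
cc.visit(ast.parse(source))
print(cc.calls)
# ['os.walk', 'os.path.join', 'data.append', 'os.path.relpath', 'os.path.join']

chained = CallCollector()
chained.visit(ast.parse("d.setdefault(10, []).append(10)"))
print(chained.calls)
# ['d.setdefault().append', 'd.setdefault']
```

Because the name is computed from the `func` subtree in one go, an inner call can no longer wipe out the outer call's attributes; the trade-off is that unsupported callee shapes (for example `funcs[0]()`) are simply skipped.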
In case you're interested in how I arrived at that solution:
Start with a very simple node visitor:
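```python
import ast

class CallCollector(ast.NodeVisitor):
    def generic_visit(self, node):
        # Print every node, plus its id (e.g. Name) or attr (e.g. Attribute) if it has one.
        try:
            print(node, node.id)
        except AttributeError:
            try:
                print(node, node.attr)
            except AttributeError:
                print(node)
        return super().generic_visit(node)

CallCollector().visit(ast.parse(yoursource))
```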
This prints a lot of output, but if you look at it you'll see some patterns, such as:
...
<_ast.Call object at 0x000001AAEE8FFA58>
<_ast.Attribute object at 0x000001AAEE8FFBE0> walk
<_ast.Name object at 0x000001AAEE8FF518> os
...
and
...
<_ast.Call object at 0x000001AAEE8FF160>
<_ast.Attribute object at 0x000001AAEE8FF588> join
<_ast.Attribute object at 0x000001AAEE8FFC50> path
<_ast.Name object at 0x000001AAEE8FF5C0> os
...
So first the call node is visited, then the attributes (if any), and finally the name. This means you have to reset the state when you visit a call node, append every attribute, and stop when you reach a name node.
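The same pattern can be seen without the noise by recording only `Call`, `Attribute` and `Name` nodes. The class name `OrderRecorder` is a hypothetical name for this illustration. Note that the argument names `pkg` and `root` also arrive as `Name` nodes right after `os`, which is why the answer's visitor stops collecting at the first `Name`.

```python
import ast


class OrderRecorder(ast.NodeVisitor):
    """Record the order in which Call, Attribute and Name nodes are visited."""

    def __init__(self):
        self.order = []

    def generic_visit(self, node):
        if isinstance(node, ast.Call):
            self.order.append("Call")
        elif isinstance(node, ast.Attribute):
            self.order.append("Attribute " + node.attr)
        elif isinstance(node, ast.Name):
            self.order.append("Name " + node.id)
        # Continue into the children (func first, then args, then keywords).
        super().generic_visit(node)


rec = OrderRecorder()
rec.visit(ast.parse("os.path.join(pkg, root)", mode="eval"))
print(rec.order)
# ['Call', 'Attribute join', 'Attribute path', 'Name os', 'Name pkg', 'Name root']
```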
One could do this inside `generic_visit`, but it's probably cleaner to do it in dedicated methods such as `visit_Call`, ..., and then just call `generic_visit` from them.
A word of caution is probably in order: this works well for simple cases, but as soon as the code becomes non-trivial it will not work reliably. For example, what if you import a subpackage? What if you bind the function to a local variable? What if you call the result of a `getattr` call? Listing the functions that are called by static analysis is probably impossible in Python, because besides the ordinary problems there are also frame hacks and dynamic assignments (for example, if some import or called function reassigned the name `os` in your module).
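To make the caution concrete, here is a small sketch in which every call in `src` ends up calling `os.path.join` at runtime, yet a purely syntactic name collector never reports that name. It uses `ast.walk` and a simplified extraction loop rather than the answer's visitor, and `call_names` is a hypothetical helper written only for this illustration.

```python
import ast


def call_names(source):
    """Collect dotted names of calls whose callee is a plain Name/Attribute chain."""
    names = []
    for node in ast.walk(ast.parse(source)):  # breadth-first traversal
        if isinstance(node, ast.Call):
            parts = []
            func = node.func
            while isinstance(func, ast.Attribute):
                parts.append(func.attr)
                func = func.value
            if isinstance(func, ast.Name):  # skip callees like getattr(...)(...)
                parts.append(func.id)
                names.append(".".join(reversed(parts)))
    return names


src = '''
import os
import os.path as p
join = os.path.join
join("a", "b")
p.join("a", "b")
getattr(os.path, "join")("a", "b")
'''
names = call_names(src)
print(names)
# ['join', 'p.join', 'getattr']
```

Resolving `join` or `p` back to `os.path` would require tracking assignments and imports, and even that cannot cover runtime tricks such as the reassignment of `os` mentioned above.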