Search code examples
pythonabstract-syntax-tree

python - ast recognize imported names


I am trying to replace names in a python file. For this I have written a class that recognizes and replaces names. Everything works fine. However, the function names of imported modules are also replaced. And for method calls, only the method name is replaced. Can you help me to solve this problems?

Code:

import ast
import random
import string

def generate_name():
    return "".join(random.choice(string.ascii_lowercase) + random.choice(string.ascii_lowercase + string.digits) for _ in range(6))

class ReplaceNames(ast.NodeTransformer):
    def __init__(self):
        self.replace_dict = {}

    def get_name(self, name):
        if not name in self.replace_dict:
            self.replace_dict[name] = generate_name()

        return self.replace_dict[name]

    def check_name(self, name):
        if not (
            name.startswith("__") and
            name.endswith("__")
        ):
            return True

        else:
            return False

    def visit_Name(self, node: ast.Name):
        if isinstance(node.ctx, ast.Store) and self.check_name(node.id):
            node.id = self.get_name(node.id)

        self.generic_visit(node)
        return node

    def visit_Attribute(self, node: ast.Attribute):
        if self.check_name(node.attr):
            node.attr = self.get_name(node.attr)

        self.generic_visit(node)
        return node

    def visit_FunctionDef(self, node: ast.FunctionDef):
        if self.check_name(node.name):
            node.name = self.get_name(node.name)

        self.generic_visit(node)
        return node

    def visit_ClassDef(self, node: ast.ClassDef):
        if self.check_name(node.name):
            node.name = self.get_name(node.name)

        self.generic_visit(node)
        return node

source = """
import time

class TestClass:
    def __init__(self):
        self.a = 0

    def add(self, value: int):
        self.a += value

    def subtract(self, value: int):
        self.a -= value

a = TestClass()
a.add(5)
time.sleep(5)
a.subtract(3)
"""
tree = ast.parse(source)
replacer = ReplaceNames()
tree = replacer.visit(tree)
print(ast.unparse(tree))

output:

import time

class z8cyt3kfw9uu:

    def __init__(self):
        self.pwk7zlx0mxe0 = 0

    def flsmoiwyeqwq(self, value: int):
        self.pwk7zlx0mxe0 += value

    def gnorpbkmaiy4(self, value: int):
        self.pwk7zlx0mxe0 -= value
pwk7zlx0mxe0 = TestClass()
a.flsmoiwyeqwq(5)
time.u6q7sum5gle9(5)
a.gnorpbkmaiy4(3)

Solution

  • The main problem is that you're not looking at the primary names of attribute references. Use node.value to get the primary (where node is an Attribute), and if it's a Name, use its .id.

    Here's a quick example you can put in your code to show where they come up. As well, I changed the outer a to t for clarity.

    ...
        def visit_Attribute(self, node: ast.Attribute):
            # DEBUG >
            orig_id = node.value.id
            orig_attr = node.attr
            # < DEBUG
    
            if self.check_name(node.attr):
                node.attr = self.get_name(node.attr)
    
            # DEBUG >
            print(
                'DEBUG Attribute node:',
                type(node.ctx).__name__,
                f'{orig_id}.{orig_attr}',
                '->',
                f'{node.value.id}.{node.attr}')
            # < DEBUG
    
            self.generic_visit(node)
            return node
    ...
    
    source = """
    import time
    
    class TestClass:
        def __init__(self):
            self.a = 0
    
        def add(self, value: int):
            self.a += value
    
        def subtract(self, value: int):
            self.a -= value
    
    t = TestClass()
    t.add(5)
    time.sleep(5)
    t.subtract(3)
    """
    tree = ast.parse(source)
    replacer = ReplaceNames()
    replacer.visit(tree)  # Assignment not needed
    print('---')
    print(ast.unparse(tree))
    

    Example output:

    DEBUG Attribute node: Store self.a -> self.pskhlmhsn9qw
    DEBUG Attribute node: Store self.a -> self.pskhlmhsn9qw
    DEBUG Attribute node: Store self.a -> self.pskhlmhsn9qw
    DEBUG Attribute node: Load t.add -> t.nqfpsiu9agzr
    DEBUG Attribute node: Load time.sleep -> time.n2wnu3seoaif
    DEBUG Attribute node: Load t.subtract -> t.xdhjcmyexwa3
    ---
    import time
    
    class sci7fnr9prnm:
    
        def __init__(self):
            self.pskhlmhsn9qw = 0
    
        def nqfpsiu9agzr(self, value: int):
            self.pskhlmhsn9qw += value
    
        def xdhjcmyexwa3(self, value: int):
            self.pskhlmhsn9qw -= value
    p4kjjowow8vf = TestClass()
    t.nqfpsiu9agzr(5)
    time.n2wnu3seoaif(5)
    t.xdhjcmyexwa3(3)
    

    Once you have that fixed, to avoid changing imported names, you can compare the .id to the list of imports. For a starting point, see Python easy way to read all import statements from py module