Tags: python, regex, parsing, abstract-syntax-tree, visitor-pattern

Replace references to variables with the values assigned to them later in a Python file


I am converting a Verilog file to a Python file using parsers and visitors. The problem is that Verilog is a declarative language while Python is an imperative one, so the order of variable assignments doesn't matter in Verilog but does matter in Python.

For example:

def F(XY_vars, util):
    i_0 = XY_vars[0, :]
    i_1 = XY_vars[1, :]
    i_2 = XY_vars[2, :]
    out = util.continuous_xor((w1),(i_2))
    w1 = util.continuous_xor((i_0),(i_1))
    return out

Observe that w1 is referenced before it is assigned. If the value of w1 were passed directly to the function, the problem would be solved. I know all the variables that will be used, as well as their values.

Is there a way to replace such arguments with the values assigned to them elsewhere in the file?

How can the ast module be used here?

If I need to write a Python visitor for this, some help with how to do that would be appreciated.

For reference, I'm adding the Verilog file from which I generated the above Python code.

// Test Sample 1: xor of 3 variables 

module formula ( i_0,i_1,i_2,out);
input i_0, i_1, i_2;
output out;
wire w1;
assign out = w1 ^ i_2;
assign w1 = i_0 ^ i_1;

endmodule

Thanks


Solution

  • To handle many input samples robustly, a solution has to account, at a minimum, for function parameter names and scope. This solution works in two passes. First, using the ast module, the original code is traversed, and all assignments are saved along with every expression that has no assignment binding at the point where it is used (a "missing" expression). Then the tree is traversed again, and each missing expression is replaced with the value that a later assignment binds to it, if such an assignment exists. The code needs Python 3.9 or later, because it uses ast.unparse and the walrus operator:

    import ast, itertools, collections as cl
    class AssgnCheck:
       def __init__(self, scopes = None):
          self.scopes, self.missing = scopes or cl.defaultdict(lambda :cl.defaultdict(list)), []
       @classmethod
       def eq_ast(cls, a1, a2):
          #check that two `ast`s are the same
          if type(a1) != type(a2):
             return False
          if isinstance(a1, list):
             return all(cls.eq_ast(*i) for i in itertools.zip_longest(a1, a2))
          if not isinstance(a1, ast.AST):
             return a1 == a2
          return all(cls.eq_ast(getattr(a1, i, None), getattr(a2, i, None)) 
                     for i in set(a1._fields)|set(a2._fields) if i != 'ctx')
       def has_bindings(self, t_ast, s_path):
          #traverse the scope stack and yield `ast`s from t_ast that do not have a value assigned to them
          for _ast in t_ast:
             if not any(any(AssgnCheck.eq_ast(_ast, b) for _, b in self.scopes[sid]['names']) for sid in s_path[::-1]):
                yield _ast
       def traverse(self, _ast, s_path = [1]):
          #walk the ast object itself
          _t_ast = None
          if isinstance(_ast, ast.Assign): #if assignment statement, add ast object to current scope
             self.scopes[s_path[-1]]['names'].append((True, _ast.targets[0]))
             self.scopes[s_path[-1]]['bindings'].append((_ast.targets[0], _ast.value))
             _ast = _ast.value
          if isinstance(_ast, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
             s_path = [*s_path, (nid:=(1 if not self.scopes else max(self.scopes)+1))]
             if isinstance(_ast, (ast.FunctionDef, ast.AsyncFunctionDef)):
                self.scopes[nid]['names'].extend([(False, ast.Name(i.arg)) for i in _ast.args.args])
                _t_ast = [*_ast.args.defaults, *_ast.body]
          self.missing.extend(list(self.has_bindings(_t_ast if _t_ast is not None else [_ast], s_path))) #determine if current ast object instance has a value assigned to it
          if _t_ast is None:
             _ast.s_path = s_path
             for _b in _ast._fields:
                if isinstance((b:=getattr(_ast, _b)), list):
                   for i in b:
                      self.traverse(i, s_path)
                elif isinstance(b, ast.AST):
                   self.traverse(b, s_path)
          else:
              for _ast in _t_ast:
                 _ast.s_path = s_path
                 self.traverse(_ast, s_path)
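
For the flat, single-scope code in the question, the first pass (finding names that are read before they are bound) can be sketched more compactly with ast.walk. This is only an illustration under simplifying assumptions, not a replacement for AssgnCheck: it assumes each function body is a flat list of statements with no branches, loops, or nested scopes. The helper name reads_before_assignment and the two sample strings are hypothetical.

```python
import ast

def reads_before_assignment(source):
    """Return (function, name, line) for each local name read before it is assigned.

    Assumption: every function body is a flat list of statements.
    """
    problems = []
    for func in ast.walk(ast.parse(source)):
        if not isinstance(func, ast.FunctionDef):
            continue
        # Parameters are bound before the first statement runs.
        bound = {a.arg for a in func.args.args}
        # Names assigned anywhere in the body are local to the function.
        assigned_somewhere = {
            n.id for stmt in func.body for n in ast.walk(stmt)
            if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Store)
        }
        for stmt in func.body:
            names = [n for n in ast.walk(stmt) if isinstance(n, ast.Name)]
            # The right-hand side is evaluated first, so check reads
            # before recording what this statement assigns.
            for n in names:
                if (isinstance(n.ctx, ast.Load) and n.id not in bound
                        and n.id in assigned_somewhere):
                    problems.append((func.name, n.id, n.lineno))
            bound.update(n.id for n in names if isinstance(n.ctx, ast.Store))
    return problems

before = """\
def F(XY_vars, util):
    i_0 = XY_vars[0, :]
    i_1 = XY_vars[1, :]
    i_2 = XY_vars[2, :]
    out = util.continuous_xor((w1),(i_2))
    w1 = util.continuous_xor((i_0),(i_1))
    return out
"""

after = """\
def F(XY_vars, util):
    i_0 = XY_vars[0, :]
    i_1 = XY_vars[1, :]
    i_2 = XY_vars[2, :]
    out = util.continuous_xor(util.continuous_xor(i_0, i_1), i_2)
    w1 = util.continuous_xor(i_0, i_1)
    return out
"""

print(reads_before_assignment(before))
print(reads_before_assignment(after))
```

Such a check can also be run on the rewritten source as a quick sanity test: an empty list means no local name is read ahead of its assignment.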
    

    Function to perform the substitutions:

    import copy
    def replace_vars(_ast, c_obj, sentinel):
       def ast_bindings(a, n, v, is_l = False):
          if not isinstance(v, ast.AST):
             return
          if v in c_obj.missing:
             c_obj.missing.remove(v)
             for sid in v.s_path[::-1]:
                if (k:=[y for x, y in c_obj.scopes[sid]['bindings'] if AssgnCheck.eq_ast(v, x)]):
                   sentinel.f = True
                   if not is_l:
                      setattr(a, n, copy.deepcopy(k[0]))
                   else:
                      a[n] = copy.deepcopy(k[0])
                   return
          replace_vars(v, c_obj, sentinel)
       if isinstance(_ast, ast.Assign):
          ast_bindings(_ast, 'value', _ast.value)
       else:
          for i in _ast._fields:
             if isinstance((k:=getattr(_ast, i)), list):
                for x, y in enumerate(k):
                   ast_bindings(k, x, y, True)
             else:
                ast_bindings(_ast, i, k)
    

    Putting it all together:

    s = """
    def F(XY_vars, util):
       i_0 = XY_vars[0, :]
       i_1 = XY_vars[1, :]
       i_2 = XY_vars[2, :]
       out = util.continuous_xor((w1),(i_2))
       w1 = util.continuous_xor((i_0),(i_1))
       return out
    """
    class Sentinel:
       def __init__(self):
          self.f = False
    
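    def replace_preref(src):
       t = ast.parse(src)
       #repeat until a full pass makes no substitution, so that chained references are resolved
       while True:
          a = AssgnCheck()
          a.traverse(t)
          sentinel = Sentinel()
          replace_vars(t, a, sentinel)
          if not sentinel.f:
             break
       return ast.unparse(t)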
    
    print(replace_preref(s))
    

    Output:

    def F(XY_vars, util):
        i_0 = XY_vars[0, :]
        i_1 = XY_vars[1, :]
        i_2 = XY_vars[2, :]
        out = util.continuous_xor(util.continuous_xor(i_0, i_1), i_2)
        w1 = util.continuous_xor(i_0, i_1)
        return out
    

    In the example above, the name w1 that was originally passed as the first argument to util.continuous_xor has been replaced with the expression assigned to w1 later in the function.
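
Since the question asks about writing a visitor, the same substitution can also be sketched with ast.NodeTransformer for the simple case. This is a minimal sketch, not the method used above. It assumes a single flat function body, each name assigned exactly once, and no cyclic definitions (a cycle would recurse without end). The class name InlineLaterAssignments and the helper inline_later_assignments are hypothetical.

```python
import ast
import copy

class InlineLaterAssignments(ast.NodeTransformer):
    """Replace reads of not-yet-assigned names with the expression assigned later.

    Assumptions: flat function body, one assignment per name, no cycles.
    """

    def __init__(self):
        self.values = {}    # name -> expression assigned to it in the body
        self.bound = set()  # names already assigned at the current statement

    def visit_FunctionDef(self, node):
        # Collect every simple `name = expr` assignment in the body.
        self.values = {
            stmt.targets[0].id: stmt.value
            for stmt in node.body
            if isinstance(stmt, ast.Assign)
            and len(stmt.targets) == 1
            and isinstance(stmt.targets[0], ast.Name)
        }
        # Parameters are bound before the first statement runs.
        self.bound = {a.arg for a in node.args.args}
        for i, stmt in enumerate(node.body):
            node.body[i] = self.visit(stmt)
            # Only after rewriting the statement do its targets count as bound.
            for n in ast.walk(stmt):
                if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Store):
                    self.bound.add(n.id)
        return node

    def visit_Name(self, node):
        if (isinstance(node.ctx, ast.Load)
                and node.id not in self.bound
                and node.id in self.values):
            # Visit the copy as well, so chains such as w1 -> w2 are resolved.
            return self.visit(copy.deepcopy(self.values[node.id]))
        return node

def inline_later_assignments(source):
    tree = InlineLaterAssignments().visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)  # ast.unparse needs Python 3.9+

sample = """\
def F(XY_vars, util):
    i_0 = XY_vars[0, :]
    i_1 = XY_vars[1, :]
    i_2 = XY_vars[2, :]
    out = util.continuous_xor((w1),(i_2))
    w2 = util.continuous_xor((i_0), (i_1))
    w1 = util.continuous_xor((w2),(i_1))
    return out
"""
result = inline_later_assignments(sample)
print(result)
```

Names that are already bound (parameters and earlier assignments) are left alone, so only forward references are inlined. Unlike the two-pass solution, this sketch does not track nested scopes.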

    Second test sample:

    s = """
    def F(XY_vars, util):    
        i_0 = XY_vars[0, :]    
        i_1 = XY_vars[1, :]    
        i_2 = XY_vars[2, :]    
        out = util.continuous_xor((w1),(i_2))    
        w2 = util.continuous_xor((i_0), (i_1))    
        w1 = util.continuous_xor((w2),(i_1))    
        return out
    """
    print(replace_preref(s))
    

    Output:

    def F(XY_vars, util):
        i_0 = XY_vars[0, :]
        i_1 = XY_vars[1, :]
        i_2 = XY_vars[2, :]
        out = util.continuous_xor(util.continuous_xor(util.continuous_xor(i_0, i_1), i_1), i_2)
        w2 = util.continuous_xor(i_0, i_1)
        w1 = util.continuous_xor(w2, i_1)
        return out
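
As the second sample shows, inlining repeats a subexpression every time a forward reference is expanded. Because the order of assignments does not matter in the Verilog source, a different technique is to reorder the assignments so that each one comes after the names it depends on. The sketch below is an alternative, not part of the solution above. It uses graphlib.TopologicalSorter from the standard library (Python 3.9+) and assumes a flat function body in which each name is assigned once. Statements that are not simple assignments, such as the final return, are kept in their original order after all assignments. The function name reorder_assignments is hypothetical.

```python
import ast
from graphlib import TopologicalSorter  # Python 3.9+

def reorder_assignments(source):
    """Reorder simple assignments in each function so definitions precede uses.

    Assumptions: flat function body, each name assigned exactly once.
    """
    tree = ast.parse(source)
    for func in ast.walk(tree):
        if not isinstance(func, ast.FunctionDef):
            continue
        assigns = {}  # target name -> assignment statement
        others = []   # everything else, e.g. the final `return`
        for stmt in func.body:
            if (isinstance(stmt, ast.Assign) and len(stmt.targets) == 1
                    and isinstance(stmt.targets[0], ast.Name)):
                assigns[stmt.targets[0].id] = stmt
            else:
                others.append(stmt)
        # Each assignment depends on the locally assigned names its value reads.
        graph = {
            name: {n.id for n in ast.walk(stmt.value)
                   if isinstance(n, ast.Name) and n.id in assigns}
            for name, stmt in assigns.items()
        }
        # static_order() yields dependencies first; it raises CycleError on loops.
        order = TopologicalSorter(graph).static_order()
        func.body = [assigns[name] for name in order] + others
    return ast.unparse(tree)

sample = """\
def F(XY_vars, util):
    i_0 = XY_vars[0, :]
    i_1 = XY_vars[1, :]
    i_2 = XY_vars[2, :]
    out = util.continuous_xor((w1),(i_2))
    w2 = util.continuous_xor((i_0), (i_1))
    w1 = util.continuous_xor((w2),(i_1))
    return out
"""
reordered = reorder_assignments(sample)
print(reordered)
```

The relative order of independent assignments is whatever the sorter produces, so only the dependency order (here w2 before w1 before out) should be relied on.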