tensorflow, neural-network, pytorch, static-analysis

Is it possible to get the architecture of a neural network built with TensorFlow or PyTorch using static analysis?


I'm currently analyzing hundreds of code repositories to identify the parameter settings of ML algorithms. In this regard, I was wondering whether it is possible to extract the architecture of neural networks built with TensorFlow and PyTorch using static analysis.

To clarify my problem, consider the development of a neural network with TF and PyTorch. Usually, a model is created by implementing a class that inherits from a TF or PyTorch base class. Within this class, the architecture (e.g., the layers) is specified. For example, see the code snippet below:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

I was wondering if I can extract the architecture using static analysis. TF provides a summary() function that prints a summary of a network, including its layers, output shapes, and number of parameters. That is exactly what I want to extract with static analysis. The rationale for using static analysis is that I am analyzing hundreds of code repositories, so it is not feasible to run the code of each repository.


Solution

  • You have to compute the Abstract Syntax Tree (AST) of the code that contains the model. For that, you can use an open-source parser such as Python's built-in ast module.
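
    As a minimal sketch, the built-in ast module can parse a source file into an AST without executing it (the path "model.py" is a placeholder for the file under analysis):

    import ast

    # Read and parse the file under analysis; nothing is executed.
    with open("model.py") as f:
        tree = ast.parse(f.read())

    # Dump the tree in a readable form for inspection (indent= needs Python 3.9+).
    print(ast.dump(tree, indent=2))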

    Once you have the AST, you can traverse it and extract the model architecture by tracking the data flow. That part is not straightforward, though. You will probably need to look at a few examples and write good tests. Note that this will not work when certain parts of the model are defined in other files; you would then need some way of doing inter-file static analysis, which is hard.
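
    As one example of such a traversal, the sketch below uses ast.NodeVisitor to collect the layer constructor calls assigned to self.* inside classes that inherit from nn.Module. It assumes the common single-file pattern from the snippet above ("model.py" is again a placeholder; ast.unparse requires Python 3.9+):

    import ast

    class LayerExtractor(ast.NodeVisitor):
        """Collects calls assigned to self.<attr> in nn.Module subclasses."""

        def __init__(self):
            self.layers = []

        def visit_ClassDef(self, node):
            # Only descend into classes that inherit from nn.Module.
            if any(ast.unparse(base) == "nn.Module" for base in node.bases):
                self.generic_visit(node)

        def visit_Assign(self, node):
            # Match statements of the form: self.<name> = <call>(...)
            target = node.targets[0]
            if (isinstance(node.value, ast.Call)
                    and isinstance(target, ast.Attribute)
                    and isinstance(target.value, ast.Name)
                    and target.value.id == "self"):
                self.layers.append(ast.unparse(node.value))

    extractor = LayerExtractor()
    extractor.visit(ast.parse(open("model.py").read()))
    print(extractor.layers)  # e.g. ["nn.Conv2d(1, 6, 5)", ...]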

    Hacky way to do it

    For each repo, you detect the files that contain a class inheriting from nn.Module. You can do this by computing the AST for each file, or simply based on the file content. Once you know the files containing a module, you create a new Python file. In this new file, you import the detected file, define an instance of that class, and then print the model summary (print(model) in PyTorch, model.summary() in TF/Keras) and write it to a file. There will probably be multiple nn.Module's per repo when parts of the model are defined in several files. To identify the main network, you can just take the longest summary. The loop below sketches the procedure, and a sample of the generated driver file follows it.

    import subprocess

    for repo in repositories:
        for file in repo.get_files():
            if containsNNModule(file):
                # Generate a driver file whose content contains:
                #   - one line importing the detected file,
                #   - one line defining an instance of the module,
                #   - one line printing the summary.
                driver_path = create_new_file_with_content(file)
                # Run the driver as a Python subprocess, capture its
                # output, and save it (save_summary is a placeholder).
                result = subprocess.run(["python", driver_path],
                                        capture_output=True, text=True)
                save_summary(repo, file, result.stdout)
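
    For illustration, the driver file generated by create_new_file_with_content could look like the sketch below (assuming, hypothetically, that the detected file is net.py at the repo root and the discovered class is Net):

    # Auto-generated driver (all names here are placeholders).
    from net import Net

    model = Net()
    # Printing a PyTorch nn.Module lists its layers; for a
    # TF/Keras model you would call model.summary() instead.
    print(model)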