AST extraction of parameters from RL scripts in multiple formats


I have multiple implementations of RL algorithms, from which I am trying to extract parameters, their data type and value.

However, these implementations differ from script to script. Sometimes the parameters are defined like this:

parser.add_argument("--env-id", type=str, default="pong_v3",
    help="the id of the environment")

parser.add_argument("--total-timesteps", type=int, default=20000000,
    help="total timesteps of the experiments")

parser.add_argument("--learning-rate", type=float, default=2.5e-4,
    help="the learning rate of the optimizer")

and also like this:

class MAA2C(Agent):
    def __init__(self, env, n_agents, state_dim, action_dim,
                 memory_capacity=10000, max_steps=None,
                 roll_out_n_steps=10,
                 reward_gamma=0.99, reward_scale=1., done_penalty=None,
                 actor_hidden_size=32, critic_hidden_size=32,
                 actor_output_act=nn.functional.log_softmax,
                 use_cuda=True, training_strategy="cocurrent",
                 actor_parameter_sharing=False, critic_parameter_sharing=False):

I am a bit lost on how to integrate extraction from all of these formats. Right now I define my parameters in a text file, and the extraction only handles the following format:

env_id: str = "CartPole-v1"
total_timesteps: int = 500000
learning_rate: float = 2.5e-4
num_envs: int = 4
num_steps: int = 128
anneal_lr: bool = True

This is my current code:

import ast
import sys
from tabulate import tabulate

def extract_parameters_with_values_from_file(file_path: str) -> dict:
    with open(file_path, 'r') as file:
        source_code = file.read()

    parameter_values = {}

    tree = ast.parse(source_code)
    for node in ast.walk(tree):
        # Only handle simple targets such as "x: int = 1" (skip "self.x: int = 1")
        if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
            parameter_name = node.target.id
            if isinstance(node.annotation, ast.Name):
                data_type = node.annotation.id
            else:
                data_type = ast.dump(node.annotation)
            value_node = node.value
            # Unwrap int(<literal>) so that e.g. int(1e6) reports 1e6
            if (isinstance(value_node, ast.Call) and isinstance(value_node.func, ast.Name)
                    and value_node.func.id == 'int' and value_node.args):
                value_node = value_node.args[0]
            if value_node is None:
                # Bare annotation such as "x: int" has no value
                parameter_value = None
            else:
                try:
                    # Covers ast.Constant (including True/False/None) and literal containers
                    parameter_value = ast.literal_eval(value_node)
                except (ValueError, SyntaxError):
                    parameter_value = ast.dump(value_node)
            parameter_values[parameter_name] = (data_type, parameter_value)

    return parameter_values

def read_parameters_from_txt(file_path: str) -> list:
    with open(file_path, 'r') as file:
        parameters = file.read().splitlines()
    return parameters

def extract_parameters_with_values(parameter_names: list, python_file_path: str) -> list:
    parameter_values = extract_parameters_with_values_from_file(python_file_path)
    extracted_parameters = []

    for parameter_name in parameter_names:
        if parameter_name in parameter_values:
            data_type, value = parameter_values[parameter_name]
            extracted_parameters.append([parameter_name, data_type, value])

    return extracted_parameters

if __name__ == "__main__":
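    if len(sys.argv) != 3:
        # Passing a string to sys.exit prints it to stderr and exits with status 1
        sys.exit(f"Usage: {sys.argv[0]} <params.txt> <script.py>")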

    parameter_txt_path = sys.argv[1]
    python_file_path = sys.argv[2]

    parameter_names = read_parameters_from_txt(parameter_txt_path)
    extracted_parameters = extract_parameters_with_values(parameter_names, python_file_path)

    if not extracted_parameters:
        print("No parameters found in the source code")
    else:
        print(tabulate(extracted_parameters, headers=['Parameter', 'Data Type', 'Value'], tablefmt="github"))

These are examples of the files whose formats I need to support: Code1, Code2.

My params.txt file looks like this:

device
policy_noise
ent_coef
vf_coef
clip_coef
gamma
batch_size
stack_size
frame_size
max_cycles
total_episodes
env_id
learning_rate
total_timesteps
buffer_size
nums_envs

I am a bit stuck on this, and any suggestions or ideas would be greatly appreciated.


Solution

  • Please take a look at the Adapter design pattern. You have multiple entry points from which data is retrieved, and each entry point presents its data in a different form. The data from every entry point should first be converted to one consistent structure (pre-processing). From there you can convert it to the desired output.

    Refactoring Guru Adapter Design Pattern

    Create one adapter for the user input from the terminal (the argparse arguments) and one for class parameters. For the latter, you can derive the parameters from the signature of the class constructor. Please check the link below for more information.

    Stackoverflow Reference

    I would advise against using a text file to list the parameters. You might get it to work, but whenever parameters change in the code, for example after a library update, your code will break and you will have to update the text file by hand. This is not sustainable. Try to obtain the parameters dynamically (automated).

    Good luck!