
How to build a forward pass of a network in PyTorch based on a string expression?


I have a string expression, self.w0*torch.sin(x)+self.w1*torch.exp(x). How can I use this expression as the forward pass of a PyTorch model? The class that instantiates the model is as follows:

class MyModule(nn.Module):
    def __init__(self,vector):
        super().__init__()
        self.s='self.w0*torch.sin(x)+self.w1*torch.exp(x)'

        w0=0.01*torch.rand(1,dtype=torch.float,requires_grad=True)
        self.w0 = nn.Parameter(w0)

        w1=0.01*torch.rand(1,dtype=torch.float,requires_grad=True)
        self.w1 = nn.Parameter(w1)

    def forward(self,x):
        return ????

For the string expression self.w0*torch.sin(x)+self.w1*torch.exp(x), the architecture of the model is as follows:

[Figure: The architecture of the model]
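Written out, the string describes a single layer with two trainable scalars, one per basis function. Because the output is linear in the weights, the gradients autograd computes for them are simply the basis functions themselves:

$$
f(x) = w_0 \sin(x) + w_1 \exp(x), \qquad
\frac{\partial f}{\partial w_0} = \sin(x), \qquad
\frac{\partial f}{\partial w_1} = \exp(x)
$$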

I have tried the following method as the forward pass:

def forward(self,x):
    return eval(self.s)

Is this the best way to implement the forward pass? Note that the string expression may vary, so I don't want to hard-code a fixed forward pass like:

def forward(self,x):
    return self.w0*torch.sin(x)+self.w1*torch.exp(x)
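For comparison, the eval-based forward from the question can at least be made to parse the string only once: Python's built-in compile() turns it into a code object in __init__, and eval then runs that object with a namespace exposing only torch, x, and the module. This is a minimal sketch; the class name EvalModule and its default expression argument are hypothetical. Restricting the namespace reduces accidents but is not a security sandbox, so untrusted strings remain unsafe.

```python
import torch
import torch.nn as nn


class EvalModule(nn.Module):
    """Hypothetical sketch: the question's eval-based forward, compiled once."""

    def __init__(self, expression='self.w0*torch.sin(x)+self.w1*torch.exp(x)'):
        super().__init__()
        self.s = expression
        # Parse and compile the string a single time instead of on every forward call.
        self._code = compile(expression, '<expression>', 'eval')
        self.w0 = nn.Parameter(0.01 * torch.rand(1))
        self.w1 = nn.Parameter(0.01 * torch.rand(1))

    def forward(self, x):
        # Only these names are visible to the expression. This limits accidents,
        # but it is NOT a sandbox: eval of untrusted strings is still unsafe.
        namespace = {'__builtins__': {}, 'torch': torch, 'self': self, 'x': x}
        return eval(self._code, namespace)


model = EvalModule()
x = torch.linspace(0.0, 1.0, 5)
y = model(x)
y.sum().backward()  # gradients flow to w0 and w1 as with a hand-written forward
print(y.shape, model.w0.grad, model.w1.grad)
```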

Solution

  • I do not recommend using eval directly, for the following reasons:

    • Security: eval executes arbitrary Python code, which is a serious risk whenever the expression can come from untrusted input.
    • Performance: eval parses and compiles the string every time it is called, i.e. on every forward pass.
    • Debugging and maintenance: code built around eval is harder to read, debug, and maintain, and mistakes in the string only surface at run time.

    However, if the forward pass really must follow an expression that can change, a safer alternative is to parse the string yourself and map each term onto one of the module's parameters and a torch function, both looked up by name with getattr. Because only attribute lookups on the module and on torch are performed, arbitrary Python cannot run (although any function in the torch namespace can still be reached). Autograd tracks the resulting operations as usual, so the parameters train normally. The example below handles sums of terms of the form self.<parameter>*torch.<function>(x):

    import torch
    import torch.nn as nn
    
    class MyModule(nn.Module):
        def __init__(self, vector):
            super().__init__()
            self.s = 'self.w0*torch.sin(x)+self.w1*torch.exp(x)'
    
            w0 = 0.01 * torch.rand(1, dtype=torch.float, requires_grad=True)
            self.w0 = nn.Parameter(w0)
    
            w1 = 0.01 * torch.rand(1, dtype=torch.float, requires_grad=True)
            self.w1 = nn.Parameter(w1)
    
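        def parse_expression(self, x, expression):
            # Handles only sums of terms shaped like "self.<param>*torch.<func>(x)".
            result = 0.0
            for term in expression.split('+'):
                weight_name, call = term.split('*')
                # 'self.w0' -> 'w0': getattr needs the bare attribute name.
                weight = getattr(self, weight_name.strip().split('.')[-1])
                # 'torch.sin(x)' -> 'sin'
                func_name = call.split('(')[0].strip().split('.')[-1]
                operation_func = getattr(torch, func_name)
                result = result + weight * operation_func(x)
            return result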
    
        def forward(self, x):
            return self.parse_expression(x, self.s)
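String splitting breaks as soon as the expression contains subtraction, constants, or nested calls. A more robust variant, swapped in here rather than taken from the answer above, uses Python's standard ast module: ast.parse only builds a syntax tree without executing anything, and a small recursive evaluator accepts just the node types it recognizes. This is a sketch under assumptions: the class name ExprModule, the ALLOWED_FUNCS whitelist, and the choice to create one scalar parameter for every self.<name> in the expression are all hypothetical.

```python
import ast
import operator

import torch
import torch.nn as nn

# Hypothetical whitelist: the only torch functions an expression may call.
ALLOWED_FUNCS = {'sin': torch.sin, 'cos': torch.cos, 'exp': torch.exp, 'tanh': torch.tanh}
BIN_OPS = {ast.Add: operator.add, ast.Sub: operator.sub,
           ast.Mult: operator.mul, ast.Div: operator.truediv}


class ExprModule(nn.Module):
    """Hypothetical sketch: evaluate a restricted expression via its syntax tree."""

    def __init__(self, expression):
        super().__init__()
        # Parse once; ast.parse builds a tree and does not execute anything.
        self.tree = ast.parse(expression, mode='eval').body
        # Register one scalar parameter for every `self.<name>` in the expression.
        names = {node.attr for node in ast.walk(self.tree)
                 if isinstance(node, ast.Attribute)
                 and isinstance(node.value, ast.Name) and node.value.id == 'self'}
        for name in sorted(names):
            self.register_parameter(name, nn.Parameter(0.01 * torch.rand(1)))

    def _eval(self, node, x):
        if isinstance(node, ast.BinOp) and type(node.op) in BIN_OPS:
            return BIN_OPS[type(node.op)](self._eval(node.left, x), self._eval(node.right, x))
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -self._eval(node.operand, x)
        if isinstance(node, ast.Name) and node.id == 'x':
            return x
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if (isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name)
                and node.value.id == 'self'):
            return getattr(self, node.attr)  # one of the registered parameters
        if (isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute)
                and isinstance(node.func.value, ast.Name) and node.func.value.id == 'torch'
                and node.func.attr in ALLOWED_FUNCS
                and len(node.args) == 1 and not node.keywords):
            return ALLOWED_FUNCS[node.func.attr](self._eval(node.args[0], x))
        # Anything else (imports, other calls, attribute chains) is rejected.
        raise ValueError(f'unsupported expression element: {ast.dump(node)}')

    def forward(self, x):
        return self._eval(self.tree, x)


model = ExprModule('self.w0*torch.sin(x)+self.w1*torch.exp(x)')
x = torch.linspace(0.0, 1.0, 5)
y = model(x)
y.sum().backward()
print(sorted(name for name, _ in model.named_parameters()))  # ['w0', 'w1']
```

Because the parameters are registered from the names found in the string, changing the expression (for example to 'self.a*torch.cos(x)-self.b*x') needs no change to the class, while a string such as "__import__('os').getcwd()" raises ValueError instead of running.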