Tags: python, pytorch, neural-network, casadi

Is there any solution that allows a CasADi MX variable to be used in a PyTorch nn.Conv1d (convolution) layer?


I have a Sequential neural network that predicts robot states. However, I run into a problem when embedding the network in CasADi to solve an MPC problem: the error says that a CasADi MX variable cannot be used in a Sequential network that contains convolution layers.

I have looked at the l4casadi repository, but it seems to support only nn.Linear, not nn.Conv1d. I hope to find a solution here; thanks for answering.


Solution

  • L4CasADi supports PyTorch models beyond linear layers (such as convolutions). It supports all PyTorch models that are jit-traceable or jit-scriptable.

    L4CasADi Example with Convolution:

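    import torch
    import numpy as np
    import l4casadi as l4c
    import casadi as cs
    
    
    # Create a model with convolutional layers (the flat input is treated as a 1x7x7 image)
    class ConvModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
    
            # Kernel size 3 with padding=1 keeps the 7x7 spatial size
            self.conv1 = torch.nn.Conv2d(1, 32, 3, padding=1)
            self.conv2 = torch.nn.Conv2d(32, 64, 3, padding=1)
            self.conv3 = torch.nn.Conv2d(64, 64, 3, padding=1)
            # Fully connected head on the flattened 64 x 7 x 7 feature map
            self.fc1 = torch.nn.Linear(64 * 7 * 7, 128)
            self.fc2 = torch.nn.Linear(128, 1)
    
        def forward(self, x):
            # Reshape the 49-element input to (batch, channels, height, width)
            x = x.reshape(-1, 1, 7, 7)
            x = torch.nn.functional.relu(self.conv1(x))
            x = torch.nn.functional.relu(self.conv2(x))
            x = torch.nn.functional.relu(self.conv3(x))
            # Flatten the feature map for the linear layers
            x = x.view(-1, 64 * 7 * 7)
            x = torch.nn.functional.relu(self.fc1(x))
            x = self.fc2(x)
            return x
    
    
    # Evaluate the plain PyTorch model on a random input (with a leading batch dimension)
    x = np.random.randn(49).astype(np.float32)
    model = ConvModel()
    y = model(torch.tensor(x)[None])
    print(f'Torch output: {y}')
    
    # Wrap the model for CasADi; the PyTorch model expects a leading batch dimension
    l4c_model = l4c.L4CasADi(model, model_expects_batch_dim=True)
    
    # The wrapped model can be called on a CasADi MX symbol like any CasADi expression
    x_sym = cs.MX.sym('x', 49)
    y_sym = l4c_model(x_sym)
    
    # Build a CasADi Function and evaluate it numerically; it should agree with the Torch output
    f = cs.Function('y', [x_sym], [y_sym])
    y = f(x)
    
    print(f'L4CasADi Output: {y}')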
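The example above uses nn.Conv2d, while the question asks about nn.Conv1d. To show what a Conv1d layer actually computes, and as a possible fallback if wrapping the whole model is not an option, the layer can be written out as explicit sums of products, which CasADi SX/MX expressions support directly. This is a minimal sketch that is not part of L4CasADi: the helper name conv1d is hypothetical, and it assumes groups=1, dilation=1, zero padding and a single sample without a batch dimension.

```python
def conv1d(x, weight, bias, stride=1, padding=0):
    """Hypothetical helper mirroring torch.nn.Conv1d for one sample.

    Assumes groups=1, dilation=1 and zero padding.
    x:      in_channels sequences of length L, indexed x[c][i]
    weight: out_channels x in_channels x kernel_size, indexed weight[o][c][k]
    bias:   out_channels values (pass zeros if the layer has bias=False)
    Only + and * are applied to the entries of x, so they can be floats
    or symbolic scalars such as CasADi SX/MX entries.
    """
    in_channels = len(x)
    length = len(x[0])
    kernel_size = len(weight[0][0])
    # Output length used by nn.Conv1d when dilation=1
    out_len = (length + 2 * padding - kernel_size) // stride + 1

    out = []
    for o in range(len(weight)):
        row = []
        for t in range(out_len):
            acc = bias[o]
            for c in range(in_channels):
                for k in range(kernel_size):
                    # Position in the unpadded input; taps outside it hit zero padding
                    i = t * stride + k - padding
                    if 0 <= i < length:
                        acc = acc + weight[o][c][k] * x[c][i]
            row.append(acc)
        out.append(row)
    return out


# Example: a single-channel difference filter with padding=1
print(conv1d([[1.0, 2.0, 3.0, 4.0]], [[[1.0, 0.0, -1.0]]], [0.5], padding=1))
# [[-1.5, -1.5, -1.5, 3.5]]
```

To use this with CasADi, one would pass nested lists of MX entries, for example [[x_sym[c * L + i] for i in range(L)] for c in range(C)], together with weights converted to Python floats via conv.weight.detach().tolist() and conv.bias.detach().tolist(); a ReLU can then be expressed with cs.fmax(v, 0). Because every kernel tap is unrolled into the expression graph, this approach grows quickly with layer size, so the L4CasADi route shown above is likely the more practical choice for larger networks.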