Search code examples
pytorch conv-neural-network activation-function

Trainable beta in swish activation function, CNN, torch


I am using the Swish activation function with a trainable 𝛽 parameter, as described in the paper "Swish: A Self-Gated Activation Function" by Prajit Ramachandran, Barret Zoph and Quoc V. Le. As a toy example, I am training LeNet-5 on MNIST and learning 'beta' instead of using the fixed beta = 1 of nn.SiLU(). I am using PyTorch 2.0 and Python 3.10. The example code is:

class LeNet5(nn.Module):
    """LeNet-5 style CNN with a Swish activation whose beta is shared.

    Swish is ``x * sigmoid(beta * x)``; with ``beta = 1`` it equals ``nn.SiLU``.
    fc1's 256 input features (16 channels * 4 * 4) correspond to inputs of
    shape (N, 1, 28, 28), i.e. MNIST-sized images.

    Args:
        beta: Initial value of the Swish beta used by all four activations.
    """

    def __init__(self, beta = 1.0):
        super().__init__()

        b = torch.tensor(data = beta, dtype = torch.float32)
        # NOTE(review): torch.autograd.Variable is deprecated and only returns a
        # plain tensor. nn.Module does not register plain tensor attributes, so
        # beta never appears in model.parameters() / state_dict() and the
        # optimizer never updates it -- this is why beta does not train.
        # Registering it as nn.Parameter(b) fixes that (then log it with
        # model.beta.item(), since f"{param:.6f}" raises for a Parameter).
        self.beta = torch.autograd.Variable(b, requires_grad = True)

        # Convolutions carry no bias: each one is followed by batch norm.
        self.conv1 = nn.Conv2d(
            in_channels = 1, out_channels = 6,
            kernel_size = 5, stride = 1,
            padding = 0, bias = False
        )
        self.bn1 = nn.BatchNorm2d(num_features = 6)
        self.pool = nn.MaxPool2d(kernel_size = 2, stride = 2)
        self.conv2 = nn.Conv2d(
            in_channels = 6, out_channels = 16,
            kernel_size = 5, stride = 1,
            padding = 0, bias = False
        )
        self.bn2 = nn.BatchNorm2d(num_features = 16)
        self.fc1 = nn.Linear(
            in_features = 256, out_features = 120,
            bias = True
        )
        self.bn3 = nn.BatchNorm1d(num_features = 120)
        self.fc2 = nn.Linear(
            in_features = 120, out_features = 84,
            bias = True
        )
        self.bn4 = nn.BatchNorm1d(num_features = 84)
        self.fc3 = nn.Linear(
            in_features = 84, out_features = 10,
            bias = True
        )

        self.initialize_weights()

    def initialize_weights(self):
        """Kaiming-normal conv/linear weights; weight=1, bias=0 for batch norms."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight)
                # The conv layers are built with bias=False, so only zero a
                # bias that actually exists.
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                # Standard initialization for batch normalization.
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def swish_fn(self, x):
        """Swish activation: ``x * sigmoid(beta * x)``."""
        return x * torch.sigmoid(x * self.beta)

    def forward(self, x):
        """Map a batch of images to 10 logits.

        Two conv -> batch norm -> max-pool -> Swish stages, then two
        linear -> batch norm -> Swish stages, then a final linear layer.
        """
        x = self.pool(self.bn1(self.conv1(x)))
        x = self.swish_fn(x = x)
        x = self.pool(self.bn2(self.conv2(x)))
        x = self.swish_fn(x = x)
        # Flatten per sample rather than view(-1, 256): a wrongly sized input
        # then fails in fc1 instead of silently splitting one image across
        # several rows of the batch.
        x = torch.flatten(x, start_dim = 1)
        x = self.bn3(self.fc1(x))
        x = self.swish_fn(x = x)
        x = self.bn4(self.fc2(x))
        x = self.swish_fn(x = x)
        x = self.fc3(x)
        return x

While training the model, I am printing 'beta' as:

# Per-epoch loop: train, validate, step the LR schedule, log metrics (including
# the learned Swish beta) and checkpoint the model with the best val accuracy.
# NOTE(review): model, optimizer, scheduler, the loaders/datasets, num_epochs,
# train_history and best_val_acc are assumed to be defined earlier (not shown).
for epoch in range(1, num_epochs + 1):

    # One epoch of training.
    train_loss, train_acc = train_one_step(
        model = model, train_loader = train_loader,
        train_dataset = train_dataset
    )

    # Validation metrics after this epoch of training.
    val_loss, val_acc = test_one_step(
        model = model, test_loader = test_loader,
        test_dataset = test_dataset
    )

    scheduler.step()
    current_lr = optimizer.param_groups[0]["lr"]

    # Convert beta to a Python float before formatting: an nn.Parameter does
    # not accept format specs such as ":.6f" (TypeError), while .item() works
    # whether beta is a plain tensor or a registered Parameter.
    beta_value = model.beta.item()

    print(f"Epoch: {epoch}; loss = {train_loss:.4f}, acc = {train_acc:.2f}%",
          f" val loss = {val_loss:.4f}, val acc = {val_acc:.2f}%,"
          f" beta = {beta_value:.6f} & LR = {current_lr:.5f}"
         )

    # Record metrics per epoch; beta is stored too so its trajectory can be
    # inspected after training.
    train_history[epoch] = {
        'train_loss': train_loss, 'val_loss': val_loss,
        'train_acc': train_acc, 'val_acc': val_acc,
        'lr': current_lr, 'beta': beta_value
    }

    # Save the model with the best validation accuracy so far.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        print(f"Saving model with highest val_acc = {val_acc:.2f}%\n")
        torch.save(model.state_dict(), "LeNet5_MNIST_best_val_acc.pth")

What am I doing wrong? Why isn't beta training as expected?


Solution

  • I just tried what I suggested in my comment: replacing autograd.Variable with nn.Parameter works.

    Variable has been deprecated for many years and should be avoided whenever possible; it has been "merged" into Tensor. Parameter is a wrapper around Tensor that registers the tensor with the module, so it is returned by model.parameters() and is therefore updated by the optimizer.

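    With Variable, beta is never updated; with Parameter, beta changes after backpropagation. One caveat: once beta is an nn.Parameter, a format spec such as f"{model.beta:.6f}" raises a TypeError, so print model.beta.item() instead.
    Here is the test code: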

    import torch
    from torch import nn
    
    
    class LeNet5(nn.Module):
        """LeNet-5 style CNN with a trainable Swish beta.

        Swish is ``x * sigmoid(beta * x)``; ``beta = 1`` reproduces ``nn.SiLU``.
        One beta is shared by all four activations. It is an ``nn.Parameter``,
        so it is returned by ``model.parameters()`` (and therefore updated by
        the optimizer) and stored in ``state_dict()``. fc1's 256 input
        features (16 channels * 4 * 4) correspond to (N, 1, 28, 28) inputs.

        Args:
            beta: Initial value of the shared Swish beta.
        """

        def __init__(self, beta=1.0):
            super().__init__()

            # nn.Parameter (not the deprecated torch.autograd.Variable) registers
            # beta with the module; requires_grad defaults to True.
            # Log it via model.beta.item(): f"{model.beta:.6f}" raises TypeError
            # for a Parameter.
            self.beta = nn.Parameter(torch.tensor(beta, dtype=torch.float32))

            # Convolutions carry no bias: each one is followed by batch norm.
            self.conv1 = nn.Conv2d(
                in_channels=1, out_channels=6,
                kernel_size=5, stride=1,
                padding=0, bias=False
            )
            self.bn1 = nn.BatchNorm2d(num_features=6)
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
            self.conv2 = nn.Conv2d(
                in_channels=6, out_channels=16,
                kernel_size=5, stride=1,
                padding=0, bias=False
            )
            self.bn2 = nn.BatchNorm2d(num_features=16)
            self.fc1 = nn.Linear(
                in_features=256, out_features=120,
                bias=True
            )
            self.bn3 = nn.BatchNorm1d(num_features=120)
            self.fc2 = nn.Linear(
                in_features=120, out_features=84,
                bias=True
            )
            self.bn4 = nn.BatchNorm1d(num_features=84)
            self.fc3 = nn.Linear(
                in_features=84, out_features=10,
                bias=True
            )

            self.initialize_weights()

        def initialize_weights(self):
            """Kaiming-normal conv/linear weights; weight=1, bias=0 for batch norms."""
            for m in self.modules():
                if isinstance(m, (nn.Conv2d, nn.Linear)):
                    nn.init.kaiming_normal_(m.weight)
                    # The conv layers are built with bias=False, so only zero a
                    # bias that actually exists.
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)

                elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                    # Standard initialization for batch normalization.
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

        def swish_fn(self, x):
            """Swish activation: ``x * sigmoid(beta * x)``."""
            return x * torch.sigmoid(x * self.beta)

        def forward(self, x):
            """Map a batch of images to 10 logits.

            Two conv -> batch norm -> max-pool -> Swish stages, then two
            linear -> batch norm -> Swish stages, then a final linear layer.
            """
            x = self.pool(self.bn1(self.conv1(x)))
            x = self.swish_fn(x=x)
            x = self.pool(self.bn2(self.conv2(x)))
            x = self.swish_fn(x=x)
            # Flatten per sample rather than view(-1, 256): a wrongly sized
            # input then fails in fc1 instead of silently splitting one image
            # across several rows of the batch.
            x = torch.flatten(x, start_dim=1)
            x = self.bn3(self.fc1(x))
            x = self.swish_fn(x=x)
            x = self.bn4(self.fc2(x))
            x = self.swish_fn(x=x)
            x = self.fc3(x)
            return x


    if __name__ == '__main__':
        model = LeNet5()
        print(model.beta)
        optim = torch.optim.Adam(model.parameters())
        optim.zero_grad()
        # 28x28 (MNIST-sized) inputs are what fc1's 256 in_features expect.
        out = model(torch.randn(32, 1, 28, 28))
        loss = out.mean()
        loss.backward()
        optim.step()

        # beta has moved away from its initial value after one optimizer step.
        print(model.beta)