Tags: keras, pytorch, torch

Initialize weights and bias in torch


What is the PyTorch equivalent of the following Keras code?

Dense(64, kernel_initializer='he_normal', bias_initializer='zeros', name='uhat_digitcaps')(d5)

How do I initialize the weights and bias?

Thanks!


Solution
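
For reference, Keras `he_normal` and PyTorch `kaiming_normal_` give the same standard deviation only when PyTorch computes the fan on the input side. Here is a worked comparison, assuming tf.keras 2.x `HeNormal` semantics and an arbitrary `nn.Linear(128, 64)` layer (a Keras `Dense(64)` on 128 input features):

```latex
\begin{aligned}
\sigma_{\text{he\_normal}} &= \sqrt{2 / \text{fan\_in}} \\
\sigma_{\text{kaiming\_normal\_}} &= \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}, \qquad \text{gain}_{\text{relu}} = \sqrt{2} \\
\text{fan\_in} = 128 &\;\Rightarrow\; \sigma = \sqrt{2/128} = 0.125 \\
\text{fan\_out} = 64 &\;\Rightarrow\; \sigma = \sqrt{2/64} \approx 0.177
\end{aligned}
```

With `mode='fan_out'` the two initializers agree only when the layer has as many inputs as outputs.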

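    import torch.nn as nn
    class Net(nn.Module):
        def __init__(self, in_features):
            super().__init__()  # required before assigning submodules
            self.linear = nn.Linear(in_features, 64)  # Dense(64, ...)
            # he_normal scales by fan_in, so use mode='fan_in' (PyTorch's default), not 'fan_out'
            nn.init.kaiming_normal_(self.linear.weight, mode='fan_in', nonlinearity='relu')
            nn.init.zeros_(self.linear.bias)  # bias_initializer='zeros'
        def forward(self, x):
            return self.linear(x)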
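
If you need the initial weights to match Keras more closely, note that tf.keras 2.x `he_normal` draws from a normal truncated at two standard deviations and rescales it so the truncated samples still have standard deviation sqrt(2 / fan_in). By contrast, `kaiming_normal_` samples an untruncated normal. The sketch below is one way to approximate the Keras behaviour with `nn.init.trunc_normal_`. The helper name `keras_he_normal_`, the 128-feature input size, and the 0.8796... correction constant are assumptions of this sketch and are not part of the original answer. The constant is the standard deviation of a standard normal truncated to [-2, 2], as used by tf.keras 2.x `VarianceScaling`.

```python
import math

import torch
import torch.nn as nn

# Assumption: tf.keras 2.x VarianceScaling divides the target stddev by this constant
# (stddev of a standard normal truncated to [-2, 2]) so that, after truncation,
# the samples still have stddev sqrt(2 / fan_in).
TRUNC_CORRECTION = 0.87962566103423978


def keras_he_normal_(weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: Keras-style he_normal init for an nn.Linear weight, in place."""
    # nn.Linear stores weight as (out_features, in_features), so fan_in is dim 1.
    fan_in = weight.shape[1]
    std = math.sqrt(2.0 / fan_in) / TRUNC_CORRECTION
    # trunc_normal_ takes absolute bounds a/b, so cut at +/- 2 std explicitly.
    return nn.init.trunc_normal_(weight, mean=0.0, std=std, a=-2.0 * std, b=2.0 * std)


torch.manual_seed(0)
layer = nn.Linear(128, 64)  # 128 input features is an arbitrary example
keras_he_normal_(layer.weight)
nn.init.zeros_(layer.bias)  # bias_initializer='zeros'

print(layer.weight.std().item())  # expected to be close to sqrt(2 / 128) = 0.125
print(layer.bias.abs().max().item())  # 0.0
```

For most training setups the plain `kaiming_normal_` call is likely close enough. The truncated variant mainly matters when you want the initial weight statistics to line up with an existing Keras model.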