Search code examples
Tags: pytorch, lstm

Decreasing the number of nodes in each layer of torch.nn.LSTM


Is there an easy way to decrease the number of nodes in each layer by a factor? I don't see this option on the documentation page; perhaps there is a similar function I can use instead of manually defining each layer?

    # Asker's current layer: nn.LSTM is given a single hidden_size, so all
    # num_layers stacked layers share the same width; none of the arguments
    # below shrinks the width from one layer to the next.
    self.lstm = nn.LSTM(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        batch_first=True,
        dropout=0.2,
    )  # one hidden_size shared by every layer

Solution

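  • Not that I know of — `nn.LSTM` takes a single `hidden_size` that is shared by all `num_layers` layers — but writing it from scratch is straightforward. The code below assumes `import torch`, `from typing import Optional`, `from torch import Tensor` and `from torch.nn import LSTM, Module, ModuleList`: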

    def _constant_scale(initial: int, factor: int) -> int:
       return initial//factor
    
    class StackedLSTM(Module):
        """Stack of LSTMs whose hidden size can differ from layer to layer.

        ``LSTM`` takes one ``hidden_size`` for all of its ``num_layers``; this
        module instead builds one ``LSTM`` per entry of ``hidden_sizes`` and
        feeds each one's output sequence into the next.

        Args:
            input_size: Number of features of the input sequence.
            hidden_sizes: Hidden size of each LSTM, first to last. Any iterable
                of ints works (list, tuple, generator, ...).
            *args: Forwarded positionally to every ``LSTM`` after
                ``input_size`` and ``hidden_size`` (i.e. ``num_layers``,
                ``bias``, ...).
            **kwargs: Forwarded to every ``LSTM`` (e.g. ``batch_first``,
                ``bidirectional``, ``proj_size``). A ``dropout`` keyword is
                applied between consecutive LSTMs (never after the last one).
                It is also passed on to the LSTMs themselves, but only when
                they have more than one internal layer.
        """

        def __init__(self, input_size: int, hidden_sizes: list[int], *args, **kwargs):
            from torch.nn import Dropout  # local import: only this block needs it

            super().__init__()

            dropout = kwargs.pop("dropout", 0.0)
            num_layers = args[0] if args else kwargs.get("num_layers", 1)
            if num_layers > 1:
                # A multi-layer sub-LSTM keeps dropout between its own layers.
                kwargs["dropout"] = dropout

            layers = []
            in_size = input_size
            for hidden_size in hidden_sizes:
                # Sizes go positionally so that extra positional args line up
                # with LSTM's following parameters (num_layers, bias, ...)
                # instead of colliding with input_size/hidden_size keywords.
                layer = LSTM(in_size, hidden_size, *args, **kwargs)
                layers.append(layer)
                # The next LSTM consumes this one's output features: proj_size
                # when a projection is configured, otherwise hidden_size, and
                # twice that when the layer is bidirectional.
                out_size = getattr(layer, "proj_size", 0) or layer.hidden_size
                in_size = out_size * (2 if layer.bidirectional else 1)

            self.layers = ModuleList(layers)
            self.inter_layer_dropout = Dropout(dropout)

        def forward(self, x: Tensor, hc: Optional[tuple[Tensor, Tensor]] = None) -> Tensor:
            """Run ``x`` through every LSTM in turn and return the last output.

            Args:
                x: Input sequence for the first LSTM.
                hc: Optional ``(h_0, c_0)`` for the FIRST LSTM only, so it must
                    be sized for ``hidden_sizes[0]``.

            Returns:
                The output sequence of the last LSTM. The final ``(h_n, c_n)``
                states of every layer are discarded. If there are no layers,
                ``x`` is returned unchanged.
            """
            last = len(self.layers) - 1
            for index, layer in enumerate(self.layers):
                x, _ = layer(x, hc)
                # Later layers have different hidden sizes, so the caller's
                # state cannot be reused; they get LSTM's default state.
                hc = None
                if index < last:
                    x = self.inter_layer_dropout(x)
            return x
    
    # Halve the width at every layer: 300 -> 150 -> 75.
    hidden_sizes = [_constant_scale(300, factor) for factor in (1, 2, 4)]
    sltm = StackedLSTM(100, hidden_sizes)
    # Input is (seq_len=10, batch=32, features=100). The initial (h, c) pair only
    # feeds the first layer, so it is sized with hidden_sizes[0] (== 300).
    x = torch.rand(10, 32, 100)
    h, c = (torch.rand(1, 32, hidden_sizes[0]) for _ in range(2))
    out = sltm(x, (h, c))
    print(out.shape)  # torch.Size([10, 32, 75])