Tags: python, python-3.x, pytorch, torch

How to understand the "torch.randn()" *size parameter arguments?


From what I understand, the arguments are torch.randn(layers/depth, rows, columns). This can be seen when executing torch.randn(2, 3, 3), which gives 2 layers of 3x3 matrices:

```
tensor([[[ 1.4838,  1.2926,  1.6147],
         [ 0.7923,  0.6414, -0.2676],
         [-0.1949,  0.3859, -0.6940]],

        [[ 0.2454, -1.9215, -0.3078],
         [ 0.8544,  0.9726,  0.0330],
         [ 0.3579,  0.8247,  2.1288]]])
```
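A minimal sketch (assuming PyTorch is installed) that checks this reading by inspecting the shape and indexing into the first dimension. The names "layers", "rows" and "columns" are only a mental model; PyTorch itself just sees a list of sizes.

```python
import torch

# Three sizes -> a 3-dimensional tensor: 2 "layers", each a 3x3 matrix
x = torch.randn(2, 3, 3)

print(x.shape)     # torch.Size([2, 3, 3])
print(x.dim())     # 3
print(x[0].shape)  # the first "layer" on its own: torch.Size([3, 3])
print(x[0, 1])     # row 1 of the first layer: a vector of 3 random values
```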

But what does adding an extra argument to *size imply? For example, torch.randn(2, 1, 3, 3):

```
tensor([[[[ 0.6206, -1.3697, -0.2267],
          [ 1.0511,  2.3375, -0.9598],
          [-0.8148, -0.0911, -2.1211]]],


        [[[ 0.0659,  1.0764,  0.6150],
          [-1.7226,  0.5038, -0.9544],
          [-0.6447, -0.3325,  0.2048]]]])
```

What did the "1" add to the created tensor?


Solution

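  • Each number you pass is the size of one dimension of the tensor, so the number of arguments is the number of dimensions. It is hard for humans to visualize more than 3 dimensions, but computers handle them just fine.

    In this particular case, you can think of the extra dimension as something like a batch size. The "1" gives that new dimension a length of 1, so it adds no new values (both tensors hold 2 x 3 x 3 = 18 numbers); it only wraps each 3x3 matrix in one more pair of brackets.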
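The sketch below (assuming PyTorch is installed) compares the 3-D and 4-D tensors directly: same number of elements, one extra axis of length 1 that `squeeze`/`unsqueeze` can remove or add. For the batch-size analogy it assumes the (N, C, H, W) layout that `torch.nn.Conv2d` expects, reading `torch.randn(2, 1, 3, 3)` as 2 single-channel 3x3 inputs; the `Conv2d` layer and its `out_channels=4` are illustrative choices, not part of the original question.

```python
import torch

a = torch.randn(2, 3, 3)     # 3 dimensions
b = torch.randn(2, 1, 3, 3)  # 4 dimensions; the new one has length 1

print(a.dim(), b.dim())      # 3 4
print(a.numel(), b.numel())  # 18 18 -> the size-1 axis adds no values

# Removing the size-1 axis (dim=1) gives back the 3-D shape...
print(b.squeeze(1).shape)    # torch.Size([2, 3, 3])
# ...and inserting one into `a` at the same position gives the 4-D shape
print(a.unsqueeze(1).shape)  # torch.Size([2, 1, 3, 3])

# More dimensions are no problem for the library either
print(torch.randn(2, 3, 4, 5, 6).dim())  # 5

# Batch reading, assuming the (N, C, H, W) layout used by nn.Conv2d:
# 2 samples, 1 channel, 3x3 spatial size each.
conv = torch.nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3)
out = conv(b)
print(out.shape)  # torch.Size([2, 4, 1, 1]) -> one result per sample
```

Because a 3x3 kernel with no padding covers the whole 3x3 input, each of the 2 samples yields a 1x1 output per output channel.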