
How can I iterate through the entire tensor of an unknown shape in PyTorch?


Let's say I have an arbitrary nn.Module. I want to iterate through all the learnable weights of this neural network and then modify every weight.

For example, I want to do something like this (where model is ResNet50):

with torch.no_grad():
  for param in model.parameters():
    print(type(param), param.size())
    param[0][0][0][0] = 0.5
    print(param[0][0][0][0])
    break

<class 'torch.nn.parameter.Parameter'> torch.Size([64, 3, 7, 7])
tensor(0.5000, device='cuda:0', requires_grad=True)

However, I want to do this to every weight without knowing the dimensions of each param. Is there any way to do this, or would I need some kind of recursion?
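Recursion is not needed even if you want to keep multi-dimensional indexing. The chained `param[0][0][0][0]` above generalizes to indexing with a tuple, and `itertools.product` can produce every index tuple for whatever shape the tensor happens to have. The sketch below assumes a single hypothetical parameter `param` standing in for one of the model's parameters:

```python
import itertools

import torch

# Hypothetical stand-in for one model parameter; the loop never hard-codes its shape.
param = torch.nn.Parameter(torch.randn(2, 3, 4))

with torch.no_grad():  # in-place edits of a parameter must not be tracked by autograd
    # itertools.product yields every index tuple: (0, 0, 0), (0, 0, 1), ...
    for idx in itertools.product(*(range(n) for n in param.shape)):
        # Indexing with a tuple works for any number of dimensions
        param[idx] = 0.5

print(param.eq(0.5).all().item())  # True
```

For a 0-dimensional tensor, `param.shape` is empty, so `itertools.product()` yields a single empty tuple and `param[()]` addresses the scalar; the same loop therefore covers every rank.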


Solution

  • You can iterate over a tensor of arbitrary size and number of dimensions by flattening it into a 1-D view.

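    import torch

    # model: any nn.Module, e.g. the ResNet50 from the question
    with torch.no_grad():  # in-place edits of parameters must not be tracked by autograd
        for param in model.parameters():
            # view(-1) is a flat 1-D view sharing param's storage (it raises an
            # error instead of silently copying if param is not contiguous), so
            # writing through it changes the parameter itself; no reshape needed
            flat = param.view(-1)
            for i in range(flat.numel()):
                flat[i] = 0.5  # replace with whatever per-element change you need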
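A Python-level loop touches one element at a time, which is very slow for a model the size of ResNet50 (about 25 million parameters). If the modification can be written as a tensor expression, it can be applied to each whole parameter in place, again without knowing its shape. The sketch below uses a small hypothetical model and an example rule (halve the positive weights) purely for illustration; substitute your own model and expression:

```python
import torch
import torch.nn as nn

# Hypothetical small model with parameters of rank 4 (conv weight),
# 2 (linear weight) and 1 (biases); any nn.Module works the same way.
model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.Flatten(), nn.Linear(4, 2))

before = [p.detach().clone() for p in model.parameters()]  # kept only for comparison

with torch.no_grad():
    for param in model.parameters():
        # Element-wise ops act on every entry regardless of rank; copy_ writes the
        # result back into the parameter's existing storage.
        param.copy_(torch.where(param > 0, param * 0.5, param))
```

Writing into the existing tensor with `copy_` (or other in-place methods such as `mul_`, `clamp_`, `fill_`) keeps the same `Parameter` objects, so any optimizer already holding references to them sees the new values.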