Tags: pytorch, tensor, pyro

Slice PyTorch tensors stored in a list


I have the following code to generate random samples. The generated samples are stored in a list, where each entry is a tensor with two elements. I would like to extract the first element from every tensor in the list, and likewise the second element from every tensor. How can I perform this kind of slice operation?

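import torch
import pyro  # needed for pyro.sample; importing pyro.distributions alone does not bind the name `pyro`
import pyro.distributions as dist

num_samples = 250
# note that the covariance matrix is diagonal
mu1 = torch.tensor([0., 5.])
sig1 = torch.tensor([[2., 0.], [0., 3.]])
dist1 = dist.MultivariateNormal(mu1, sig1)
# draw num_samples independent samples; the result is a Python list of 1D tensors of shape (2,)
samples1 = [pyro.sample('samples1', dist1) for _ in range(num_samples)]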

samples1
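To see the structure of the list without installing Pyro, here is a minimal sketch that assumes torch.distributions.MultivariateNormal as a stand-in for dist.MultivariateNormal (Pyro's version wraps the same PyTorch class) and uses a fixed seed. It checks that the list holds num_samples tensors of shape (2,), and that indexing one entry with an integer returns a 0-dimensional tensor.

```python
import torch
from torch.distributions import MultivariateNormal  # stand-in for pyro.distributions (assumption)

torch.manual_seed(0)  # fixed seed so the sketch is reproducible

num_samples = 250
mu1 = torch.tensor([0., 5.])
sig1 = torch.tensor([[2., 0.], [0., 3.]])
dist1 = MultivariateNormal(mu1, sig1)

# Each call draws one 2-element sample, so the list holds num_samples tensors of shape (2,)
samples1 = [dist1.sample() for _ in range(num_samples)]

print(len(samples1))         # 250
print(samples1[0].shape)     # torch.Size([2])
# Integer indexing drops a dimension: samples1[0][0] is a 0-dim (scalar) tensor
print(samples1[0][0].dim())  # 0
```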



Solution

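  • I'd recommend torch.cat with a list comprehension. Slice with t[0:1] rather than t[0]: integer indexing returns a 0-dimensional tensor, and torch.cat raises an error on zero-dimensional inputs.

    col1 = torch.cat([t[0:1] for t in samples1])  # each t[0:1] has shape (1,)
    col2 = torch.cat([t[1:2] for t in samples1])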
    

    Docs for torch.cat: https://pytorch.org/docs/stable/generated/torch.cat.html

    ALTERNATIVELY

    You could turn your list of 1D tensors into a single 2D tensor of shape (num_samples, 2) using torch.stack, then do a normal slice:

    samples1_t = torch.stack(samples1)
    col1 = samples1_t[:, 0]  # : means all rows
    col2 = samples1_t[:, 1]
    

    Docs for torch.stack: https://pytorch.org/docs/stable/generated/torch.stack.html
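Two further variations, sketched under the same assumption as above (torch.distributions.MultivariateNormal standing in for the Pyro distribution, fixed seed): Tensor.unbind splits the stacked tensor into its columns in one call, and if the list itself is not needed, Distribution.sample with a sample_shape draws all samples as one (num_samples, 2) tensor that can be sliced directly.

```python
import torch
from torch.distributions import MultivariateNormal  # stand-in for pyro.distributions (assumption)

torch.manual_seed(0)  # fixed seed so the sketch is reproducible

num_samples = 250
dist1 = MultivariateNormal(torch.tensor([0., 5.]),
                           torch.tensor([[2., 0.], [0., 3.]]))
samples1 = [dist1.sample() for _ in range(num_samples)]

# Stack into shape (num_samples, 2), then unbind along dim 1:
# this returns a tuple with one 1-D tensor per column
col1, col2 = torch.stack(samples1).unbind(dim=1)

# Skip the list entirely: draw all samples at once as shape (num_samples, 2)
batch = dist1.sample((num_samples,))
b_col1, b_col2 = batch[:, 0], batch[:, 1]

print(col1.shape, col2.shape)  # torch.Size([250]) torch.Size([250])
print(batch.shape)             # torch.Size([250, 2])
```

Sampling in one batch avoids building a Python list of tiny tensors, which is usually simpler and faster when the samples are only used as columns afterwards.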