Tags: python, pytorch, numpy-ndarray, tensor, binary-matrix

Creating a pytorch tensor binary mask using specific values


I am given a 2-D PyTorch tensor of integers, and two integers that always appear in each row of the tensor. I want to create a binary mask that contains 1 between the appearances of these two integers (endpoints included) and 0 otherwise. For example, if the integers are 4 and 2 and one row is [1,1,9,4,6,5,1,2,9,9,11,4,3,6,5,2,3,4], the returned mask for that row will be [0,0,0,1,1,1,1,1,0,0,0,1,1,1,1,1,0,0]. Is there an efficient way to compute this mask without iterating?
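To pin down the rule being asked for, here is a plain-loop sketch for a single row. It is only a baseline to compare a vectorized answer against, and it rests on my reading of the example: an interval opens at the first value, closes at the next occurrence of the second value, and an opening with no matching close contributes nothing. The name `reference_mask` is made up for illustration.

```python
def reference_mask(row, start, end):
    """Loop-based baseline: 1 from each `start` through the next `end`, inclusive."""
    mask = [0] * len(row)
    open_at = None  # index of the currently open `start`, if any
    for i, x in enumerate(row):
        if open_at is None and x == start:
            open_at = i
        elif open_at is not None and x == end:
            # close the interval and mark every position in it, endpoints included
            for j in range(open_at, i + 1):
                mask[j] = 1
            open_at = None
    return mask


row = [1, 1, 9, 4, 6, 5, 1, 2, 9, 9, 11, 4, 3, 6, 5, 2, 3, 4]
print(reference_mask(row, 4, 2))
```

The trailing 4 in that row never meets a closing 2, so under this reading it stays 0.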


Solution

  • Based completely on the previous solution, here is the revised one:

    import torch
    
    vals = [2, 8]  # the two values that appear in every row
    
    # target tensor
    m = torch.tensor([[1., 2., 7., 8., 5., 2., 6., 5., 8., 4.],
                      [4., 7., 2., 1., 8., 2., 6., 5., 6., 8.]])
    
    # find the positions of those values
    k = m == vals[0]
    p = m == vals[1]
    
    # v is True wherever either value occurs
    v = k | p
    # column index of every hit, one row of m per row of nz_indexes;
    # the hard-coded 4 assumes exactly two start/end pairs in each row
    nz_indexes = v.nonzero()[:, 1].reshape(m.shape[0], 4)
    
    # tile the column indexes so q has the same shape as m
    q = torch.arange(m.shape[1])
    q = q.repeat(m.shape[0], 1)
    
    # one pair of masks per interval: columns at or after its start,
    # and columns at or before its end
    msk_0 = nz_indexes[:, 0].repeat(m.shape[1], 1).transpose(0, 1) <= q
    msk_1 = nz_indexes[:, 1].repeat(m.shape[1], 1).transpose(0, 1) >= q
    msk_2 = nz_indexes[:, 2].repeat(m.shape[1], 1).transpose(0, 1) <= q
    msk_3 = nz_indexes[:, 3].repeat(m.shape[1], 1).transpose(0, 1) >= q
    
    # a column lies inside an interval when both masks of that pair are True
    final_mask = msk_0.int() * msk_1.int() + msk_2.int() * msk_3.int()
    
    print(final_mask)
    

    and we finally get:

    tensor([[0, 1, 1, 1, 0, 1, 1, 1, 1, 0],
            [0, 0, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int32)
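If the number of pairs per row is not fixed at two, a running count may be simpler than building one mask pair per interval. The sketch below is not the answer's method but a swapped-in alternative: like the code above, it treats either value as a delimiter and pairs the hits in order (1st with 2nd, 3rd with 4th, and so on). It assumes every row contains an even number of hits. `between_mask` is a name chosen here for illustration.

```python
import torch


def between_mask(m: torch.Tensor, vals) -> torch.Tensor:
    """Mask that is 1 from each odd-numbered hit through the following hit."""
    # True wherever either delimiter value occurs
    hits = (m == vals[0]) | (m == vals[1])
    # inclusive running count of hits along each row
    count = hits.long().cumsum(dim=1)
    # an odd count means an interval is currently open; OR-ing with hits
    # adds back the closing delimiter, where the count has just become even
    return ((count % 2 == 1) | hits).int()


m = torch.tensor([[1., 2., 7., 8., 5., 2., 6., 5., 8., 4.],
                  [4., 7., 2., 1., 8., 2., 6., 5., 6., 8.]])
mask = between_mask(m, [2, 8])
print(mask)
```

On the tensor `m` from the answer this reproduces the mask printed above. One caveat: if a row has an odd number of hits, the last interval never closes and the mask stays 1 through the end of that row, so such rows would need separate handling.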