Identifying and removing duplicate columns/rows in sparse binary matrix in PyTorch


Suppose we have a binary matrix A with shape n x m. I want to identify the rows that have duplicates in the matrix, i.e. rows for which another index along the same dimension holds the same elements in the same positions.

It's very important not to convert this matrix into a dense representation, since the real matrices I'm using are quite large and hard to fit in memory.

Using PyTorch for the implementation:

import torch

# This is just a toy sparse binary matrix with n = 10 and m = 100
A = torch.randint(0, 2, (10, 100), dtype=torch.float32).to_sparse()

Intuitively, we can multiply the matrix by its own transpose, producing a new m x m matrix whose entry (i, j) is the number of positions along dimension 0 where index i and index j both contain a 1. (As written, this compares columns; A @ A.T gives the n x n equivalent for rows.)

B = A.T @ A # In PyTorch, this operation will also produce a sparse representation
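
To make the criterion concrete: two binary vectors are identical exactly when their overlap count equals the number of 1s in each of them. The sketch below checks this on a tiny hand-written matrix. The name `A_small` is made up for illustration, `torch.sparse.mm(A_small.t(), A_small)` is used as what should be an equivalent spelling of `A.T @ A`, and the result is densified only because the toy is 3 x 3.

```python
import torch

# Toy matrix: columns 0 and 2 are identical, column 1 contains both of them.
A_small = torch.tensor([[1., 1., 1.],
                        [0., 1., 0.],
                        [1., 1., 1.]]).to_sparse()

# Entry (i, j) counts the positions where columns i and j both hold a 1.
overlap = torch.sparse.mm(A_small.t(), A_small).to_dense()

# Number of 1s in each column.
col_sums = torch.sparse.sum(A_small, dim=0).to_dense()

# Columns i and j are identical iff overlap[i, j] == col_sums[i] == col_sums[j].
same = (overlap == col_sums[:, None]) & (overlap == col_sums[None, :])
print(same)
```

Column 1 overlaps columns 0 and 2 in two positions but has three 1s itself, so it is not reported as a duplicate of either.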

At this point, I tried to combine these values by comparing them with the per-index sums A.sum(0):

num_elements = A.sum(0)
duplicate_rows = torch.logical_and([
   num_elements[B.indices()[0]] == num_elements[B.indices()[1]],
   num_elements[B.indices()[0]] == B.values()
])

But this did not work!
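
One likely reason the attempt above fails is that `torch.logical_and` takes two tensors rather than a list. Densifying the per-column sums (a vector of length m, not the full matrix) and coalescing `B` before reading its indices should also make the indexing more predictable. The following is a minimal sketch under those assumptions; the function name `duplicate_column_pairs` is hypothetical, and it mirrors the column-wise `A.T @ A` formulation of the question.

```python
import torch

def duplicate_column_pairs(A):
    """Return a 2 x k tensor of column pairs (i, j), i < j, whose columns are identical."""
    A = A.coalesce()
    # Overlap counts between columns; coalesce so indices()/values() can be read.
    B = torch.sparse.mm(A.t(), A).coalesce()
    # Per-column count of 1s as a dense length-m vector.
    num_elements = torch.sparse.sum(A, dim=0).to_dense()
    i, j = B.indices()
    # Chain the conditions with `&`; i < j reports each pair once and skips the diagonal.
    is_dup = (num_elements[i] == num_elements[j]) & (num_elements[i] == B.values()) & (i < j)
    return B.indices()[:, is_dup]

# Columns 0 and 2 are identical; column 3 is empty.
D = torch.tensor([[1., 0., 1., 0.],
                  [0., 1., 0., 0.],
                  [1., 1., 1., 0.]])
print(duplicate_column_pairs(D.to_sparse()))
```

Only stored entries of `B` are inspected, so the dense n x m matrix is never materialized inside the function; the dense `D` here is just a convenient way to write the toy input.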

I think the solution can be written using only operations on PyTorch sparse tensors (without Python loops and so on), which should also help performance.


Solution

  • Here is an implementation that identifies the duplicate rows in a binary sparse matrix. It returns a mask of the rows to keep, but can easily be adjusted to return, e.g., the indices of the duplicate rows. It also handles groups of 3 or more rows that are duplicates of each other, keeping only 1 row per group (always the lowest-index row, for simplicity).

    import torch

    def get_unique_row_mask_sparse(A):
        # Number of matching 1s between each pair of rows.
        # coalesce() ensures indices() and values() can be read from the result.
        B = (A @ A.T).coalesce()

        # Number of 1s in each row (a dense vector of length n)
        row_sums = torch.sparse.sum(A, dim=1).to_dense()

        indices = B.indices()
        i, j = indices[0], indices[1]

        # Two rows i and j are duplicates if:
        # 1) B[i,j] == row_sums[i] == row_sums[j]
        # 2) i != j (exclude diagonal)
        # Moreover, we only keep the upper triangle of the matrix (i < j) so each
        # pair is reported once; this also covers condition 2.
        same_row_sums = row_sums[i] == row_sums[j]
        matches_equal_sums = B.values() == row_sums[i]
        upper_triangular = i < j

        is_duplicate_pair = same_row_sums & matches_equal_sums & upper_triangular
        duplicate_pairs = indices[:, is_duplicate_pair]

        # For each duplicate pair (i,j), we keep row i and drop row j.
        # Indexed assignment clears all such rows at once, without a Python loop.
        keep_mask = torch.ones(A.size(0), dtype=torch.bool, device=A.device)
        keep_mask[duplicate_pairs[1]] = False

        return keep_mask
    

    Testing code:

    torch.manual_seed(42)
    A = torch.randint(0, 2, (10, 100), dtype=torch.float32).to_sparse()
    
    # Force some duplicate rows for testing
    A_dense = A.to_dense()
    A_dense[3] = A_dense[1]
    A_dense[6] = A_dense[1]
    A_dense[9] = A_dense[2]
    A = A_dense.to_sparse()
    
    keep_mask = get_unique_row_mask_sparse(A)
    print(keep_mask)
    

    Gives the result:

    tensor([ True,  True,  True, False,  True,  True, False,  True,  True, False])
    

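    You can run the following to create a new sparse tensor containing only the kept rows. The kept rows are renumbered from 0, since the original row indices would otherwise fall outside the smaller shape.

    A_indices = A.indices()
    rows_mask = keep_mask[A_indices[0]]
    # Renumber the kept rows 0..k-1 so the indices fit the smaller matrix
    new_row_ids = torch.cumsum(keep_mask.long(), dim=0) - 1
    new_indices = A_indices[:, rows_mask]
    new_indices[0] = new_row_ids[new_indices[0]]
    A_unique = torch.sparse_coo_tensor(
        new_indices,
        A.values()[rows_mask],
        (int(keep_mask.sum()), A.size(1))
    ).coalesce()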
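
One caveat, as far as I can tell: rows that contain no 1s at all produce no stored entries in `A @ A.T`, so `get_unique_row_mask_sparse` never flags two all-zero rows as duplicates of each other. If that matters for your data, the mask can be patched afterwards. This is a minimal sketch; the helper name `drop_repeated_empty_rows` is hypothetical, and it assumes the same "keep the lowest index" convention as above.

```python
import torch

def drop_repeated_empty_rows(A, keep_mask):
    """Clear keep_mask for every all-zero row of A except the first one."""
    # Dense vector of per-row counts; the matrix itself stays sparse.
    row_sums = torch.sparse.sum(A, dim=1).to_dense()
    empty_rows = torch.nonzero(row_sums == 0).squeeze(1)
    keep_mask = keep_mask.clone()
    # Keep the lowest-index empty row and drop the remaining ones.
    keep_mask[empty_rows[1:]] = False
    return keep_mask

# Rows 0, 2 and 3 are empty; only row 0 of those should survive.
A_toy = torch.tensor([[0., 0.],
                      [1., 0.],
                      [0., 0.],
                      [0., 0.]]).to_sparse()
toy_mask = drop_repeated_empty_rows(A_toy, torch.ones(4, dtype=torch.bool))
print(toy_mask)
```

In the workflow above, this would be applied to `keep_mask` before building `A_unique`.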