python numpy scipy sparse-matrix

Numpy: create new array by combining subsets of previous array


I have a binary sparse CSR array. I would like to create a new array by combining columns from this original array. That is, I have a list of "column groups": [[1,10,3], [5,54,202], [12,199], [5], ...]

For each of these "column groups" I want to combine columns from the original array with an OR operation (np.max works for this) and add the combined column to a new matrix.
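For binary (0/1) data, the row-wise maximum over a group of columns is 1 exactly when at least one of those columns is 1, which is a logical OR. A small check on a dense 0/1 array (the values and the group are made up purely for illustration):

```python
import numpy as np

# Hypothetical 0/1 data: 3 rows, 4 columns
a = np.array([[0, 1, 0, 0],
              [0, 0, 0, 0],
              [1, 0, 1, 0]])

group = [0, 2]  # one hypothetical "column group"

# Row-wise max over the group's columns
via_max = a[:, group].max(axis=1)
# Row-wise logical OR over the same columns, cast back to the data's dtype
via_or = np.logical_or.reduce(a[:, group], axis=1).astype(a.dtype)

print(via_max)                          # [0 0 1]
print(np.array_equal(via_max, via_or))  # True
```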

My current solution uses hstack, but it's quite slow:

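import numpy as np

data = np.empty((data_orig.shape[0], 0), dtype=data_orig.dtype)  # start with zero columns
for cg in column_groups:
    tmp = np.max(data_orig[:,cg].toarray(), axis=1, keepdims=True)
    data = np.hstack((data, tmp))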
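Part of the cost of the loop is that np.hstack copies the whole accumulated array on every iteration. A minimal sketch of the same loop that collects the per-group columns in a Python list and stacks them once at the end, assuming data_orig is a SciPy CSR matrix (grouped_max_loop and the example data are hypothetical names and values used only for illustration):

```python
import numpy as np
from scipy import sparse

def grouped_max_loop(data_orig, column_groups):
    # One dense (n_rows, 1) column per group: the row-wise max (OR for 0/1 data)
    parts = [np.max(data_orig[:, cg].toarray(), axis=1, keepdims=True)
             for cg in column_groups]
    # Stack once instead of re-copying a growing array on every iteration
    return np.hstack(parts)

# Hypothetical example data
data_orig = sparse.csr_matrix(np.array([[0, 1, 0, 1],
                                        [1, 0, 0, 0],
                                        [0, 0, 1, 0]]))
column_groups = [[0, 1], [2], [1, 3]]
print(grouped_max_loop(data_orig, column_groups))
```

This still loops over the groups in Python, so it only removes the repeated copying, not the per-group indexing overhead.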

Solution

  • At each iteration you are taking the row-wise maximum over one group of columns. So we can select all the grouped columns in one go and then use np.maximum.reduceat to get the maximum over each interval of columns, which gives us a vectorized solution, like so -

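    import numpy as np

    def grouped_max(data_orig, column_groups):
        # Flatten all groups into one array of column indices
        cols = np.hstack(column_groups)
        # Start offset of each group within cols (list() is needed on Python 3)
        clens = np.hstack((0,np.cumsum(list(map(len,column_groups)))[:-1]))
        # Densify only the selected columns, then take the max over each interval
        all_data = data_orig[:,cols].toarray()
        return np.maximum.reduceat(all_data, clens,axis=1)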
    

    On Python 3.x, where map returns an iterator rather than a list, clens has to be computed with list(...) around map, like so -

    clens = np.hstack((0,np.cumsum(list(map(len,column_groups)))[:-1]))
    

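    Since the loopy version iterates over the groups and grows data with np.hstack on every iteration (copying everything accumulated so far each time), this vectorized solution should show its benefits when working with a large number of groups. Note that it densifies all selected columns at once with toarray(), so its memory use grows with the total number of selected columns.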

    Sample run -

    In [303]: # Setup sample csr matrix
         ...: a = np.random.randint(0,3,(12,28))
         ...: data_orig = sparse.csr_matrix(a)
         ...: 
         ...: # Random column IDs
         ...: column_groups = [[1,10,3], [5,14],[2]]
         ...: 
         ...: data = np.empty((12,0),dtype=int)
         ...: for cg in column_groups:
         ...:     tmp = np.max(data_orig[:,cg].toarray(), axis=1, keepdims=True)
         ...:     data = np.hstack((data, tmp))
         ...:     
    
    In [304]: out = grouped_max(data_orig, column_groups)
    
    In [305]: # Verify results between original and proposed ones
         ...: print(np.allclose(out, data))
    True
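To see what np.maximum.reduceat does with the group offsets, here is a small dense example (the array and group sizes are made up for illustration and stand in for data_orig[:, cols].toarray()). Output column i is the maximum over columns clens[i] up to, but not including, clens[i+1]; the last one runs to the end. One caveat: this relies on every group being non-empty, because an empty group produces two equal consecutive offsets (or an offset past the end), and reduceat then does not compute an empty reduction.

```python
import numpy as np

# Hypothetical dense data standing in for data_orig[:, cols].toarray()
all_data = np.array([[0, 1, 0, 0, 1],
                     [1, 0, 0, 0, 0],
                     [0, 0, 1, 1, 0]])

group_sizes = [2, 1, 2]                              # lengths of three groups
clens = np.hstack((0, np.cumsum(group_sizes)[:-1]))  # start offset of each group

# Output column i = row-wise max over all_data[:, clens[i]:clens[i+1]]
out = np.maximum.reduceat(all_data, clens, axis=1)
print(clens)  # [0 2 3]
print(out)
```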