python numpy machine-learning classification sparse-matrix

Replace specific values in a matrix using Python


I have an m x n matrix where each row is a sample and each column is a class. Each row contains the softmax probabilities for each class. I want to replace the maximum value in each row with 1 and all other values with 0. How can I do this efficiently in Python?


Solution

  • I think the best answer to your particular question is to use a sparse matrix object.

    The result has exactly one nonzero entry per row, so most of the matrix is zeros. A sparse matrix should therefore be the most memory-friendly way to store many such matrices, or very large ones. For matrices that are large in both dimensions, this should beat using NumPy arrays directly in terms of memory, if not necessarily in terms of computation speed.

    import numpy as np
    import scipy.sparse   # import the submodule explicitly; plain `import scipy` may not load it

    matrix = np.matrix(np.random.randn(10, 5))
    maxes = matrix.argmax(axis=1).A1   # column index of each row's maximum, as a flat 1-D array
                                       # (.A[:, 0] also works and is slightly faster, but .A1 reads better)
    n_rows = len(matrix)               # equivalent to matrix.shape[0]
    data = np.ones(n_rows)             # one nonzero entry per row
    row = np.arange(n_rows)
    sparse_matrix = scipy.sparse.coo_matrix((data, (row, maxes)),
                                            shape=matrix.shape,
                                            dtype=np.int8)
    

    This sparse_matrix object should be very lightweight relative to a regular matrix object, which would needlessly store every single zero. To materialize it as a normal matrix:

    sparse_matrix.todense()
    

    returns:

    matrix([[0, 0, 0, 0, 1],
            [0, 0, 1, 0, 0],
            [0, 0, 1, 0, 0],
            [0, 0, 0, 0, 1],
            [1, 0, 0, 0, 0],
            [0, 0, 1, 0, 0],
            [0, 0, 0, 1, 0],
            [0, 1, 0, 0, 0],
            [1, 0, 0, 0, 0],
            [0, 0, 0, 1, 0]], dtype=int8)
    

    We can compare this to the original matrix, where each 1 above sits at the position of its row's maximum:

    matrix([[ 1.41049496,  0.24737968, -0.70849012,  0.24794031,  1.9231408 ],
            [-0.08323096, -0.32134873,  2.14154425, -1.30430663,  0.64934781],
            [ 0.56249379,  0.07851507,  0.63024234, -0.38683508, -1.75887624],
            [-0.41063182,  0.15657594,  0.11175805,  0.37646245,  1.58261556],
            [ 1.10421356, -0.26151637,  0.64442885, -1.23544526, -0.91119517],
            [ 0.51384883,  1.5901419 ,  1.92496778, -1.23541699,  1.00231508],
            [-2.42759787, -0.23592018, -0.33534536,  0.17577329, -1.14793293],
            [-0.06051458,  1.24004714,  1.23588228, -0.11727146, -0.02627196],
            [ 1.66071534, -0.07734444,  1.40305686, -1.02098911, -1.10752638],
            [ 0.12466003, -1.60874191,  1.81127175,  2.26257234, -1.26008476]])
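
As a rough check of the memory argument above, here is a minimal sketch that wraps the same approach in a function and compares the storage used by the sparse result with its dense equivalent. The function name `one_hot_argmax`, the 1000 x 50 shape, and the random softmax input are assumptions for illustration only. It uses a plain ndarray rather than np.matrix, which the NumPy documentation no longer recommends for new code.

```python
import numpy as np
from scipy import sparse


def one_hot_argmax(probs):
    """Return a sparse 0/1 matrix with a single 1 at each row's maximum.

    Hypothetical helper for illustration; on ties, argmax picks the first maximum.
    """
    probs = np.asarray(probs)
    n_rows = probs.shape[0]
    cols = probs.argmax(axis=1)            # column index of each row's maximum
    rows = np.arange(n_rows)
    data = np.ones(n_rows, dtype=np.int8)  # one nonzero entry per row
    return sparse.coo_matrix((data, (rows, cols)), shape=probs.shape)


# Assumed example input: softmax probabilities for 1000 samples and 50 classes.
rng = np.random.default_rng(0)
logits = rng.standard_normal((1000, 50))
exp = np.exp(logits - logits.max(axis=1, keepdims=True))
probs = exp / exp.sum(axis=1, keepdims=True)

one_hot = one_hot_argmax(probs)
dense = one_hot.toarray()                  # dense int8 array, stores every zero

# COO format stores three parallel arrays: the values and their row/column indices.
sparse_bytes = one_hot.data.nbytes + one_hot.row.nbytes + one_hot.col.nbytes
dense_bytes = dense.nbytes

print("sparse bytes:", sparse_bytes)
print("dense bytes: ", dense_bytes)
```

The sparse form pays a fixed cost per row for the value and its two indices, while the dense int8 form pays one byte per column per row. The saving therefore grows with the number of classes, and for a matrix with only a handful of columns the dense array may be just as small.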