Tags: python, numpy, scipy, sparse-matrix, logarithm

Log-sum-exp trick on a sparse matrix in scipy


What's the best way to apply something like scipy.misc.logsumexp to a sparse matrix (for instance a scipy.sparse.csr_matrix), specifying one axis?

The point is to leave the zeros out of the computation.

UPDATE

To be more specific: I'm looking for something that performs the log-sum-exp trick. Simply applying exp element-wise, summing the rows and taking the log element-wise is trivial in scipy.sparse. What is less trivial is cleanly computing the max along each row and subtracting it from that row's elements (each element in row i minus the i-th entry of the max vector), while still ending up with a sparse matrix.
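For the part the question calls less trivial, here is a minimal sketch, assuming the max is taken over each row's stored entries only and that implicit zeros should stay implicit. `subtract_row_max` is a hypothetical helper name; it finds each row's max with `np.maximum.reduceat` over the row start offsets in `indptr` (empty rows are skipped, since `reduceat` does not treat a zero-length segment as empty) and broadcasts it onto that row's stored values with `np.repeat`:

```python
import numpy as np
import scipy.sparse as sp

def subtract_row_max(X):
    """Return (Y, row_max): Y is a CSR copy of X in which every stored entry
    of row i has row_max[i] (the max over row i's stored entries) subtracted.
    Implicit zeros stay implicit; empty rows get row_max = -inf."""
    Y = sp.csr_matrix(X, dtype=float, copy=True)
    Y.sum_duplicates()                      # one stored value per (i, j)
    counts = np.diff(Y.indptr)              # stored entries per row
    nonempty = counts > 0
    row_max = np.full(Y.shape[0], -np.inf)
    if nonempty.any():
        # reduce only over the starts of non-empty rows: reduceat returns
        # a[start] (not an empty reduction) for zero-length segments
        row_max[nonempty] = np.maximum.reduceat(Y.data, Y.indptr[:-1][nonempty])
    # Broadcast each row's max onto that row's stored values.
    Y.data -= np.repeat(row_max, counts)
    return Y, row_max
```

The result keeps the sparsity pattern of X, but each row's maximum now appears as an explicitly stored 0; call `Y.eliminate_zeros()` afterwards if that matters.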


Solution

  • The stored (non-zero) entries of row i of a CSR matrix X are obtained by

    X[i].data
    

    and (a permutation of) the values of the full row i are obtained by appending X.shape[1] - len(X[i].data) zeros to them.
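To make the permutation claim concrete, a small sketch on a made-up 2 × 4 matrix (`padded_row` is a hypothetical helper, and canonical CSR without duplicate entries is assumed): the stored values of a row followed by the right number of zeros contain the same values as the dense row, only in a different order.

```python
import numpy as np
import scipy.sparse as sp

def padded_row(X, i):
    """Stored values of row i of CSR matrix X followed by its implicit zeros,
    i.e. a permutation of the dense row (assumes no duplicate entries)."""
    b = X[i].data
    k = X.shape[1] - len(b)
    return np.concatenate([b, np.zeros(k)])

X = sp.csr_matrix(np.array([[0., 2., 0., -1.], [3., 0., 0., 0.]]))
for i in range(X.shape[0]):
    dense_row = X[i].toarray().ravel()
    # same multiset of values, possibly in a different order
    assert np.array_equal(np.sort(padded_row(X, i)), np.sort(dense_row))
```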

    Recall the log-sum-exp trick,

    logsumexp(a) = max(a) + log(∑ exp[a - max(a)])
    

    for a vector a. Let's set b = X[i].data and k = X.shape[1] - len(X[i].data) and denote our earlier permuted row of X as

    (b, 0ₖ)
    

    using 0ₖ to denote a zero vector of length k and (⋅, ⋅) for concatenation. Then, assuming k > 0 (if the row has no zeros at all, max((b, 0ₖ)) is simply max(b)),

    logsumexp((b, 0ₖ))
     = max((b, 0ₖ)) + log(∑ exp[(b, 0ₖ) - max((b, 0ₖ))])
     = max(max(b), 0) + log(∑ exp[(b, 0ₖ) - max(max(b), 0)])
     = max(max(b), 0) + log(∑ exp[b - max(max(b), 0)] + ∑ exp[0ₖ - max(max(b), 0)])
     = max(max(b), 0) + log(∑ exp[b - max(max(b), 0)] + k × exp[-max(max(b), 0)])
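A quick numerical check of the last line of the derivation, as a sketch: the hypothetical helper `lse_with_zeros(b, k)` evaluates max(max(b), 0) + log(∑ exp[b - max(max(b), 0)] + k × exp[-max(max(b), 0)]) for k > 0, and it is compared with `scipy.special.logsumexp` on the explicitly concatenated vector (b, 0ₖ). The large entry also shows why the max is subtracted at all: a naive log(∑ exp(⋅)) overflows.

```python
import numpy as np
from scipy.special import logsumexp

def lse_with_zeros(b, k):
    """logsumexp of (b, 0_k) without materialising the k zeros (requires k > 0)."""
    b = np.asarray(b, dtype=float)
    m = max(b.max(initial=-np.inf), 0.0)        # max((b, 0_k)) = max(max(b), 0)
    return m + np.log(np.exp(b - m).sum() + k * np.exp(-m))

b = np.array([1.0, -2.0, 3.0])
k = 2
full = np.concatenate([b, np.zeros(k)])
print(lse_with_zeros(b, k), logsumexp(full))    # the two values agree

# Without the shift, exp(1000) overflows to inf; the shifted form stays finite.
big = np.array([1000.0, 1.0])
with np.errstate(over="ignore"):
    print(np.log(np.exp(big).sum() + 3))        # inf
print(lse_with_zeros(big, 3))                   # about 1000.0
```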
    

    So we get the algorithm

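    import numpy as np

    def logsumexp_csr_row(x):
        # log-sum-exp over ALL entries of a 1 x n CSR row, implicit zeros included
        data = x.data.astype(float)
        k = x.shape[1] - len(data)          # number of implicit zeros in the row
        if k > 0:
            mx = max(data.max(initial=-np.inf), 0.0)  # the zeros take part in the max
            zeros = k * np.exp(-mx)         # the k zeros contribute k * exp(0 - mx)
        else:
            mx, zeros = data.max(), 0.0     # fully dense row: the max is just max(b)
        tmp = data - mx
        r = np.exp(tmp, out=tmp).sum()      # exp in place, no second temporary
        return mx + np.log(r + zeros)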
    

    for a CSR row vector. Extending this to the full matrix can be done with a list comprehension over the rows X[i], although it is more efficient to loop over the rows using X.indptr directly, which avoids creating a new row matrix on every iteration:

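    def logsumexp_csr_rows(X):
        n_rows, n_cols = X.shape
        result = np.empty(n_rows)
        for i in range(n_rows):
            data = X.data[X.indptr[i]:X.indptr[i+1]].astype(float)  # stored entries of row i
            k = n_cols - len(data)          # implicit zeros in row i
            if k > 0:
                mx = max(data.max(initial=-np.inf), 0.0)
                zeros = k * np.exp(-mx)
            else:
                mx, zeros = data.max(), 0.0
            r = np.exp(data - mx).sum()
            result[i] = mx + np.log(r + zeros)
        return result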
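The loop over rows can also be removed. The following is a vectorised variant of the same formula, sketched as an alternative and assuming canonical CSR input: the hypothetical `logsumexp_csr_rows_vectorized` uses `np.maximum.reduceat` and `np.add.reduceat` over the starts of the non-empty rows in `X.indptr`. It adds the k × exp[-mx] term only for rows that actually contain zeros, so fully dense rows with very negative entries neither underflow nor produce 0 × inf:

```python
import numpy as np
import scipy.sparse as sp

def logsumexp_csr_rows_vectorized(X):
    """Row-wise log-sum-exp over all n_cols entries of each row (implicit
    zeros included), without a Python-level loop over the rows."""
    X = sp.csr_matrix(X, dtype=float, copy=True)
    X.sum_duplicates()                     # canonical CSR: one value per (i, j)
    n_rows, n_cols = X.shape
    counts = np.diff(X.indptr)             # stored entries per row
    k = n_cols - counts                    # implicit zeros per row
    nonempty = counts > 0
    starts = X.indptr[:-1][nonempty]       # reduceat mis-handles empty segments

    # Row max over stored entries; -inf for rows with nothing stored.
    row_max = np.full(n_rows, -np.inf)
    if nonempty.any():
        row_max[nonempty] = np.maximum.reduceat(X.data, starts)
    # The zeros join the max only in rows that actually contain zeros.
    has_zeros = k > 0
    mx = np.where(has_zeros, np.maximum(row_max, 0.0), row_max)

    # Sum of exp(b - mx) over the stored entries of each row.
    r = np.zeros(n_rows)
    if nonempty.any():
        shifted = np.exp(X.data - np.repeat(mx, counts))
        r[nonempty] = np.add.reduceat(shifted, starts)

    # Contribution of the k zeros: k * exp(-mx), with mx >= 0 wherever k > 0.
    zeros = np.zeros(n_rows)
    zeros[has_zeros] = k[has_zeros] * np.exp(-mx[has_zeros])
    return mx + np.log(r + zeros)
```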
    

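    A column-wise version is much trickier with CSR, because a column's entries are scattered across all rows; it's probably easiest to transpose the matrix, convert it back to CSR (X.T.tocsr()) and apply the row-wise version.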
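A sketch of that transpose-and-convert route, assuming the same zeros-included semantics (`logsumexp_csr_cols` is a hypothetical name): after `X.T.tocsr()`, each row of the transposed matrix holds one column of X, so the row-wise formula applies unchanged, with the original number of rows playing the role of the row length.

```python
import numpy as np
import scipy.sparse as sp

def logsumexp_csr_cols(X):
    """Column-wise log-sum-exp of sparse X, implicit zeros included: the
    transpose is converted back to CSR so each column becomes a row."""
    Xt = sp.csr_matrix(X, dtype=float).T.tocsr()   # rows of Xt = columns of X
    Xt.sum_duplicates()
    n_out, n_len = Xt.shape                        # n_len = original row count
    result = np.empty(n_out)
    for j in range(n_out):
        b = Xt.data[Xt.indptr[j]:Xt.indptr[j + 1]]
        k = n_len - len(b)                         # implicit zeros in column j
        if k > 0:
            mx = max(b.max(initial=-np.inf), 0.0)
            zeros = k * np.exp(-mx)
        else:
            mx, zeros = b.max(), 0.0
        result[j] = mx + np.log(np.exp(b - mx).sum() + zeros)
    return result
```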


    UPDATE: OK, I misunderstood the question: the OP is not interested in handling the zeros at all, so the derivation above is unnecessary and the algorithm is simply

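    import numpy as np
    from scipy.special import logsumexp   # formerly scipy.misc.logsumexp

    def logsumexp_row_nonzeros(X):
        result = np.full(X.shape[0], -np.inf)   # empty row: log of an empty sum
        for i in range(X.shape[0]):
            row = X.data[X.indptr[i]:X.indptr[i+1]]
            if row.size:
                result[i] = logsumexp(row)
        return result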
    

    This just fills in the general scheme for row-wise operations on a CSR matrix. For column-wise results, transpose, convert back to CSR (X.T.tocsr()) and apply the above.
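If the per-row loop becomes a bottleneck, the zeros-ignored version can also be vectorised. A sketch, assuming canonical sparse input; `logsumexp_stored` and its `axis` argument are hypothetical, not a SciPy API. Choosing CSR for `axis=1` and CSC for `axis=0` makes `indptr` walk along the requested axis, which avoids the explicit transpose:

```python
import numpy as np
import scipy.sparse as sp

def logsumexp_stored(X, axis=1):
    """log-sum-exp over the *stored* entries of sparse X along `axis`
    (1: one value per row, 0: one value per column). Implicit zeros are
    ignored; a row/column with nothing stored gives -inf."""
    # Pick the compressed format whose indptr walks along the output axis.
    fmt = sp.csr_matrix if axis == 1 else sp.csc_matrix
    Y = fmt(X, dtype=float, copy=True)
    Y.sum_duplicates()
    counts = np.diff(Y.indptr)
    nonempty = counts > 0
    out = np.full(len(counts), -np.inf)
    if not nonempty.any():
        return out
    starts = Y.indptr[:-1][nonempty]           # skip empty segments for reduceat
    mx = np.maximum.reduceat(Y.data, starts)   # max of each non-empty row/column
    shifted = np.exp(Y.data - np.repeat(mx, counts[nonempty]))
    out[nonempty] = mx + np.log(np.add.reduceat(shifted, starts))
    return out
```

Rows or columns with nothing stored come out as -inf, the log of an empty sum.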