Tags: deep-learning, pytorch, lstm

Regarding one code segment in computing log_sum_exp


In this tutorial on using PyTorch to implement a BiLSTM-CRF, the author implements the following function. Specifically, I don't quite understand what max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1]) is trying to do, or which mathematical formula it corresponds to.

import torch

def argmax(vec):
    # helper defined earlier in the same tutorial: returns the argmax as a Python int
    _, idx = torch.max(vec, 1)
    return idx.item()

# Compute log sum exp in a numerically stable way for the forward algorithm
def log_sum_exp(vec):
    max_score = vec[0, argmax(vec)]
    max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
    return max_score + \
        torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))

Solution

  • Looking at the code, it seems like vec has a shape of (1, n).
    Now we can follow the code line by line:

    max_score = vec[0, argmax(vec)]
    

    Indexing vec at row 0 and column argmax(vec) is just a roundabout way of taking the maximal value of vec. So, max_score is (as the name suggests) the maximal value of vec.
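
    A tiny sketch of this equivalence (assuming torch and the tutorial's argmax helper are in scope):

    vec = torch.tensor([[1.0, 5.0, 3.0]])
    print(vec[0, argmax(vec)])  # tensor(5.)
    print(vec.max())            # tensor(5.) -- the same value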

    max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
    

    Next, we want to subtract max_score from each of the elements of vec. To do so, the code creates a tensor of the same shape as vec in which every element equals max_score.
    First, max_score is reshaped into a 2D tensor of shape (1, 1) using the view command; then that tensor is "stretched" to shape (1, n) using the expand command.
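
    For instance, a minimal sketch of what these two calls do (the values are made up for illustration):

    import torch
    vec = torch.tensor([[1.0, 5.0, 3.0]])  # shape (1, 3)
    max_score = vec[0, 1]                  # scalar tensor, value 5.0
    max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
    print(max_score_broadcast)        # tensor([[5., 5., 5.]])
    print(vec - max_score_broadcast)  # tensor([[-4.,  0., -2.]])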

    Finally, the log sum exp is computed robustly:

    return max_score + \
        torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))
    

    The validity of this computation rests on the identity

    log(sum_i exp(x_i)) = m + log(sum_i exp(x_i - m)),  where m = max_i x_i,

    which holds because exp(x_i) = exp(m) * exp(x_i - m), so exp(m) factors out of the sum and log(exp(m)) = m.

    The rationale behind it is that exp(x) can overflow for large positive x; therefore, for numerical stability, it is best to subtract the maximal value before exponentiating, which makes the largest argument to exp exactly zero.
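
    A quick illustration of the difference, using the log_sum_exp function above:

    big = torch.tensor([[1000.0, 999.0]])
    print(torch.log(torch.sum(torch.exp(big))))  # tensor(inf) -- the naive computation overflows
    print(log_sum_exp(big))                      # tensor(1000.3133) -- the stable version does not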


    As a side note, I think a slightly more elegant way to do the same computation, taking advantage of broadcasting, would be

    max_score, _ = vec.max(dim=1, keepdim=True)  # max along dim 1; keepdim=True keeps shape (1, 1) so broadcasting works
    lse = max_score + torch.log(torch.sum(torch.exp(vec - max_score), dim=1))
    return lse
    

    Also note that log sum exp is already implemented by PyTorch: torch.logsumexp.
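
    For example, a minimal sketch of the built-in:

    import torch
    vec = torch.randn(1, 5)
    lse = torch.logsumexp(vec, dim=1)  # shape (1,); numerically stable out of the box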