In this tutorial on using PyTorch to implement a BiLSTM-CRF, the author implements the following function. Specifically, I don't quite understand what `max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])` is trying to do, or which math formula it corresponds to.
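```python
import torch

def argmax(vec):  # helper from the tutorial: index of the largest entry as a Python int
    _, idx = torch.max(vec, 1)
    return idx.item()

# Compute log sum exp in a numerically stable way for the forward algorithm
def log_sum_exp(vec):
    max_score = vec[0, argmax(vec)]
    max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
    return max_score + \
        torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))
```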
Looking at the code, it seems like `vec` has a shape of `(1, n)`.
Now we can follow the code line by line:
```python
max_score = vec[0, argmax(vec)]
```
Indexing `vec` at location `[0, argmax(vec)]` is just a fancy way of taking the maximal value of `vec`. So `max_score` is (as the name suggests) the maximal value of `vec`, stored as a 0-dimensional tensor.
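A quick check of this step, as a minimal sketch: the `argmax` helper below is written to mirror the one in the tutorial (an assumption, since the question does not show it), and the example values are made up.

```python
import torch

def argmax(vec):
    # assumed to match the tutorial's helper: index of the largest entry as a Python int
    _, idx = torch.max(vec, 1)
    return idx.item()

vec = torch.tensor([[1.0, 5.0, 3.0]])       # made-up scores, shape (1, n) with n = 3
max_score = vec[0, argmax(vec)]             # picks vec[0, 1]
print(max_score.item(), vec.max().item())   # 5.0 5.0: same as taking the max directly
print(max_score.shape)                      # torch.Size([]): a 0-dimensional tensor
```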
```python
max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
```
Next, we want to subtract `max_score` from each of the elements of `vec`. To do so, the code creates a tensor of the same shape as `vec` with all elements equal to `max_score`.
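First, `max_score` (a 0-dimensional tensor) is reshaped to have two dimensions, shape `(1, 1)`, using the `view` command; the `-1` in `view(1, -1)` lets PyTorch infer that dimension's size. Then this `(1, 1)` tensor is "stretched" to shape `(1, n)` using the `expand` command. Note that `expand` does not copy the value `n` times: it returns a view of the same memory with stride 0 along the expanded dimension.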
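The two reshaping steps can be inspected on a small made-up example. This sketch prints the intermediate shapes and compares the result with plain broadcasting (subtracting the 0-dimensional `max_score` directly), which gives the same values without the explicit `view`/`expand`.

```python
import torch

vec = torch.tensor([[1.0, 5.0, 3.0]])   # made-up scores, shape (1, 3)
max_score = vec.max()                   # 0-dim tensor holding 5.0

step1 = max_score.view(1, -1)                          # shape (1, 1)
max_score_broadcast = step1.expand(1, vec.size()[1])   # shape (1, 3), every entry 5.0
print(step1.shape, max_score_broadcast.shape)  # torch.Size([1, 1]) torch.Size([1, 3])
print(max_score_broadcast)                     # tensor([[5., 5., 5.]])
print(max_score_broadcast.stride())            # (1, 0): stride 0 means no copy was made

diff = vec - max_score_broadcast
print(torch.equal(diff, vec - max_score))      # True: broadcasting does the same thing
```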
Finally, the log sum exp is computed robustly:
```python
return max_score + \
    torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))
```
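The validity of this computation follows from factoring $e^{m}$ out of the sum, where `vec` holds $x_1, \dots, x_n$ and `max_score` is $m = \max_i x_i$:

$$\log \sum_{i=1}^{n} e^{x_i} = \log\Big(e^{m} \sum_{i=1}^{n} e^{x_i - m}\Big) = m + \log \sum_{i=1}^{n} e^{x_i - m}$$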
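The rationale behind it is that `exp(x)` can "explode" (overflow to `inf`) for large positive `x`; in `float32` this already happens for `x` above roughly 88.7. Therefore, for numerical stability, it is best to subtract the maximal value before taking `exp`: every exponent is then at most 0, so each term lies in `(0, 1]`, and the term for the maximum itself is exactly 1, so the sum is at least 1 and its `log` stays finite.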
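A small illustration with made-up scores, assuming the default `float32` dtype: the naive formula overflows, while the shifted one stays finite.

```python
import torch

vec = torch.tensor([[1000.0, 1001.0, 1002.0]])    # made-up large scores, shape (1, 3)

naive = torch.log(torch.sum(torch.exp(vec)))      # exp(1000) overflows float32 to inf
m = vec.max()
stable = m + torch.log(torch.sum(torch.exp(vec - m)))  # exponents become -2, -1, 0

print(naive)          # tensor(inf)
print(stable.item())  # about 1002.4076, i.e. 1002 + log(e^-2 + e^-1 + 1)
```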
As a side note, I think a slightly more elegant way to do the same computation, taking advantage of broadcasting, would be
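```python
max_score, _ = vec.max(dim=1, keepdim=True)  # max along 2nd dim, shape (batch, 1)
# (batch, n) - (batch, 1) broadcasts, so no explicit expand is needed
lse = max_score + torch.log(torch.sum(torch.exp(vec - max_score), dim=1, keepdim=True))
return lse.squeeze(1)  # one value per row, shape (batch,)
```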
Also note that log sum exp is already implemented by PyTorch: `torch.logsumexp`.
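As a sanity check, here is a sketch comparing the tutorial's function against `torch.logsumexp` on one made-up row. The `argmax` helper is written to mirror the tutorial's (an assumption here). Unlike the tutorial's version, `torch.logsumexp` takes a `dim` argument and returns one value per row, so it handles batches directly.

```python
import torch

def argmax(vec):
    # assumed to match the tutorial's helper: index of the largest entry as a Python int
    _, idx = torch.max(vec, 1)
    return idx.item()

def log_sum_exp(vec):
    # the tutorial's stable version, for a (1, n) tensor
    max_score = vec[0, argmax(vec)]
    max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
    return max_score + torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))

vec = torch.tensor([[0.5, -1.0, 2.0, 1000.0]])   # made-up scores, shape (1, 4)
ours = log_sum_exp(vec)                          # 0-dim tensor
builtin = torch.logsumexp(vec, dim=1)            # shape (1,): one value per row
print(torch.allclose(ours, builtin[0]))          # True

batch = torch.randn(4, 5)                        # several rows at once
print(torch.logsumexp(batch, dim=1).shape)       # torch.Size([4])
```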