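I want to compute the formula below in MATLAB (the E-step of EM for a multinomial mixture model), where x_{nm} = data(n,m) is the count of item m in row n of the data:

$$g_{nk} = \frac{\lambda_k \prod_{m=1}^{M} \theta_{km}^{x_{nm}}}{\sum_{k'=1}^{K} \lambda_{k'} \prod_{m=1}^{M} \theta_{k'm}^{x_{nm}}}$$

g and θ are matrices, and θ and λ have the following constraints:

$$\sum_{k=1}^{K} \lambda_k = 1, \qquad \sum_{m=1}^{M} \theta_{km} = 1 \quad \text{for each } k$$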
But M, the number of values of m, is larger than 1593, and when I compute the product over θ the result becomes so small that MATLAB stores it as zero (floating-point underflow).
Can anyone simplify the formula for g, or suggest another trick to solve this problem?
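To make the problem concrete, here is a minimal sketch with made-up values (not the posted data): it assumes M = 1600 items, a uniform θ row and a count of 1 for every item. The direct product underflows to zero, while the same quantity on the log scale is an ordinary finite number.

```matlab
% Hypothetical example: M = 1600 items, uniform theta, one count per item.
M = 1600;
theta_k = ones(1, M) / M;       % one row of theta, sums to 1
x_n = ones(1, M);               % counts for one row of data
p = prod(theta_k .^ x_n);       % (1/1600)^1600 is about 1e-5126
disp(p)                         % prints 0: the product has underflowed
lp = sum(x_n .* log(theta_k));  % the same quantity on the log scale
disp(lp)                        % about -11804, easily representable
```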
Update:
Data: data.txt (after downloading, change the file extension to .mat)
Code:
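function [g, landa, theta] = EM(data)
    % EM for a multinomial mixture model.
    % data: N x M matrix of counts. Returns responsibilities g (N x K),
    % mixing weights landa (K x 1) and component parameters theta (K x M).
    %% initialize
    K = 2;
    [N, M] = size(data);
    g = zeros(N, K);
    landa = ones(K, 1) .* 0.5;
    theta = rand(M, K);
    theta = bsxfun(@rdivide, theta, sum(theta, 1))';   % K x M, each row sums to 1
    %% EM
    for i = 1:10
        %% E Step
        for n = 1:N
            normalize = 0;
            for k = 1:K
                % this product underflows to 0 when M is large
                g(n, k) = landa(k) * prod(theta(k, :) .^ data(n, :));
                normalize = normalize + g(n, k);
            end
            g(n, :) = g(n, :) ./ normalize;
        end
        %% M Step
        for k = 1:K
            landa(k) = sum(g(:, k)) / N;
            for m = 1:M
                % add-one (Laplace) smoothing keeps theta strictly positive
                theta(k, m) = (sum(g(:, k) .* data(:, m)) + 1) / (sum(g(:, k) .* sum(data, 2)) + M);
            end
        end
    end
end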
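As a side check on the M step in the code above, here is a sketch of a vectorized equivalent, run on hypothetical toy data with random responsibilities. It also confirms that the update preserves the constraints: summing the numerators over m gives exactly the denominator, so each row of theta sums to 1, and landa sums to 1 because each row of g does.

```matlab
% Hypothetical toy data: N = 4 rows, M = 3 items, K = 2 components.
data = [1 0 2; 0 3 1; 2 2 0; 1 1 1];
[N, M] = size(data);
K = 2;
g = rand(N, K);
g = bsxfun(@rdivide, g, sum(g, 2));                     % responsibilities: rows sum to 1

% Vectorized M step
landa = sum(g, 1)' / N;                                 % K x 1
theta = bsxfun(@plus, g' * data, 1);                    % K x M numerators
theta = bsxfun(@rdivide, theta, g' * sum(data, 2) + M); % divide by K x 1 denominators

% Loop version from the question, for comparison
theta_loop = zeros(K, M);
for k = 1:K
    for m = 1:M
        theta_loop(k, m) = (sum(g(:, k) .* data(:, m)) + 1) / (sum(g(:, k) .* sum(data, 2)) + M);
    end
end

disp(sum(landa))                                % 1
disp(sum(theta, 2))                             % each row sums to 1
disp(max(abs(theta(:) - theta_loop(:))))        % rounding-level difference only
```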
You can use computations on logarithms instead of the actual values to avoid underflow problems.
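This works because of the identity log(λ_k ∏_m θ_km^x_nm) = log λ_k + Σ_m x_nm log θ_km: the product of many tiny factors becomes a sum of moderate negative numbers. A quick check with hypothetical numbers, small enough that the direct product does not underflow:

```matlab
% Hypothetical small values (K = 1 component, M = 3 items).
lambda_k = 0.3;
theta_k = [0.2 0.5 0.3];
x_n = [2 0 3];
direct = lambda_k * prod(theta_k .^ x_n);                   % 0.3 * 0.04 * 1 * 0.027
via_log = exp(log(lambda_k) + sum(x_n .* log(theta_k)));    % same value via logs
fprintf('%g %g\n', direct, via_log);
```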
To start, we slightly reformat the E-step code:
for n = 1 : N
    for k = 1 : K
        g(n, k) = lambda(k) * prod(theta(k, :) .^ data(n, :));
    end
end
g = bsxfun(@rdivide, g, sum(g, 2));
So instead of accumulating the denominator in an extra variable normalize, we do the normalization in one step after both loops.
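A small sketch, on hypothetical inputs small enough not to underflow, confirming that accumulating the denominator inside the loop and normalizing each row once afterwards give the same g:

```matlab
% Hypothetical small inputs: N = 2 rows, M = 3 items, K = 2 components.
data = [1 0 2; 0 3 1];
lambda = [0.4; 0.6];
theta = [0.2 0.3 0.5; 0.6 0.1 0.3];   % each row sums to 1
[N, ~] = size(data);
K = numel(lambda);

% Original form: accumulate the denominator while looping
g1 = zeros(N, K);
for n = 1:N
    normalize = 0;
    for k = 1:K
        g1(n, k) = lambda(k) * prod(theta(k, :) .^ data(n, :));
        normalize = normalize + g1(n, k);
    end
    g1(n, :) = g1(n, :) ./ normalize;
end

% Reformatted form: fill g first, then normalize every row in one step
g2 = zeros(N, K);
for n = 1:N
    for k = 1:K
        g2(n, k) = lambda(k) * prod(theta(k, :) .^ data(n, :));
    end
end
g2 = bsxfun(@rdivide, g2, sum(g2, 2));

disp(max(abs(g1(:) - g2(:))))   % rounding-level difference only
```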
Now we introduce a variable lg, which contains the logarithm of g:
lg = zeros(N, K);
for n = 1 : N
    for k = 1 : K
        lg(n, k) = log(lambda(k)) + sum(log(theta(k, :)) .* data(n, :));
    end
end
g = exp(lg);
g = bsxfun(@rdivide, g, sum(g, 2));
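The double loop that fills lg can also be written as one matrix product, since entry (n, k) of data * log(theta)' is Σ_m data(n,m) log θ_km. A sketch on hypothetical inputs (lambda as a K x 1 vector, theta as K x M, data as N x M), checked against the loop:

```matlab
% Hypothetical inputs: N = 3 rows, M = 3 items, K = 2 components.
data = [1 0 2; 0 3 1; 4 1 0];
lambda = [0.4; 0.6];
theta = [0.2 0.3 0.5; 0.6 0.1 0.3];
[N, ~] = size(data);
K = numel(lambda);

% Loop version, as in the answer
lg_loop = zeros(N, K);
for n = 1:N
    for k = 1:K
        lg_loop(n, k) = log(lambda(k)) + sum(log(theta(k, :)) .* data(n, :));
    end
end

% Vectorized version: (N x M) * (M x K), plus log(lambda) added to every row
lg_vec = bsxfun(@plus, data * log(theta)', log(lambda)');

disp(max(abs(lg_loop(:) - lg_vec(:))))   % rounding-level difference only
```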
So far, nothing has been achieved: the underflow has only moved from inside the loop to the subsequent conversion from lg to g via the exponential.
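A quick illustration with hypothetical log values of the order of -2000: in double precision exp returns 0 for arguments below about -745, so every entry of g becomes 0 and the normalization produces 0/0 = NaN.

```matlab
% Hypothetical log-scale values for N = 2 rows, K = 2 components.
lg = [-2000 -2010; -1990 -2100];
g = exp(lg);                          % every entry underflows to 0
g = bsxfun(@rdivide, g, sum(g, 2));   % 0 / 0 gives NaN
disp(g)
```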
But the very next line is the normalization step, which means that the exact values of g are not really needed: all that matters is that the values have the correct ratios to each other. We can therefore divide all values that jointly enter a normalization (here, each row of g) by an arbitrary constant without changing the end result. On the logarithmic scale, this means subtracting a constant, and we choose that constant to be the arithmetic mean of each row of lg (which corresponds to dividing g by its geometric mean):
lg = bsxfun(@minus, lg, mean(lg, 2));
g = exp(lg);
g = bsxfun(@rdivide, g, sum(g, 2));
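A sketch of why the shift is harmless, on hypothetical values: with moderate log values, subtracting an arbitrary constant per row leaves the normalized g unchanged, and with extreme values the mean shift turns rows that would otherwise become NaN into finite, properly normalized ones.

```matlab
% Moderate hypothetical log values: compare unshifted and shifted results.
lg = [-3 -5; -1 -2];
c = [10; -7];                                       % arbitrary constant per row
g_ref = exp(lg);
g_ref = bsxfun(@rdivide, g_ref, sum(g_ref, 2));
g_shift = exp(bsxfun(@minus, lg, c));
g_shift = bsxfun(@rdivide, g_shift, sum(g_shift, 2));
disp(max(abs(g_ref(:) - g_shift(:))))               % rounding-level difference only

% Extreme hypothetical log values, with the row mean subtracted.
lg_big = [-2000 -2010; -1990 -2100];
lg_big = bsxfun(@minus, lg_big, mean(lg_big, 2));   % rows become [5 -5] and [55 -55]
g_big = exp(lg_big);
g_big = bsxfun(@rdivide, g_big, sum(g_big, 2));
disp(g_big)                                         % finite, rows sum to 1
```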
Thanks to the subtraction, the logarithmic values move from something like -2000, which doesn't survive the exponential, to something like +50 or -30. The values of g are now sensible and can easily be normalized to obtain the correct result.
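One caveat, offered as a hedged variant rather than as part of the answer above: the mean shift can still overflow when the values within a row are far apart. With K = 2, two log values 2000 apart are mapped to +1000 and -1000, and exp(1000) overflows to Inf, so that row normalizes to NaN. A common alternative (the log-sum-exp trick, a swapped-in technique) subtracts the row maximum instead: the largest entry becomes exp(0) = 1, so the denominator is at least 1 and nothing overflows, while very small entries underflow harmlessly to 0. A sketch on hypothetical values:

```matlab
% Hypothetical log values with a large spread in row 1 (K = 2).
lg = [-1000 -3000; -2000 -2010];

% Mean shift: row 1 becomes [1000 -1000], and exp(1000) is Inf
lg_mean = bsxfun(@minus, lg, mean(lg, 2));
g_mean = exp(lg_mean);
g_mean = bsxfun(@rdivide, g_mean, sum(g_mean, 2));
disp(g_mean)    % row 1 is [NaN 0] (Inf / Inf and 0 / Inf), row 2 is fine

% Max shift: the largest entry of each row becomes 0
lg_max = bsxfun(@minus, lg, max(lg, [], 2));
g_max = exp(lg_max);
g_max = bsxfun(@rdivide, g_max, sum(g_max, 2));
disp(g_max)     % row 1 is [1 0], row 2 matches the mean-shift result
```

The same shift also gives the per-row log-likelihood without underflow, as max(lg, [], 2) + log(sum(exp(lg_max), 2)), which can be handy for monitoring EM convergence.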