python statistics data-science probability expectation-maximization

Expectation Maximization in Python


I'm tasked with implementing the expectation-maximization algorithm for a class I'm in. In the notes, my professor worked out the iterative formulas used in the code. I've checked them, and they're written correctly.

The question asks us to create synthetic data from a given model. This model is written in the gauss_mix() function below. My final output is not what it should be, though, and I'm stumped as to why.

import numpy as np
import pylab as plt

# Create a synthetic Dataset
def gauss_mix(x, pi1, mu1, mu2, sigma):
    term1 = pi1 * np.exp(-(x - mu1)**2 / 2*sigma**2)
    term2 = (1 - pi1) * np.exp(-(x - mu2)**2 / 2*sigma**2)
    return np.array(term1 + term2)

# Now we define the initial parameters
# The format of the list is: (pi_1, mu_1, mu_2, sigma)
initial_params = [.3, 5, 15, 2]

rand_position = np.random.rand(1,10000)*30
synth_data = gauss_mix(rand_position[0], initial_params[0], initial_params[1], initial_params[2], initial_params[3])

To see the plot, you can sort the rand_position[0] values before calculating gauss_mix. This produces the following plot:

[plot of gauss_mix evaluated at the sorted positions]

Moving on, I defined a couple of functions to assist with the calculations.

# Defining a couple of useful functions 
def gamma_1n_old(pi1_old, norm1, norm2):
    # responsibility of the first Gaussian: the posterior probability that
    # each data point came from component 1. Formula given in the book
    numerator = pi1_old * norm1
    denominator = pi1_old * norm1 + (1-pi1_old) * norm2 
    return np.array(numerator / denominator)

def gamma_2n_old(pi1_old, norm1, norm2):
    # responsibility of the second Gaussian: the posterior probability that
    # each data point came from component 2. Formula given in the book
    numerator = (1-pi1_old) * norm2
    denominator = pi1_old * norm1 + (1-pi1_old) * norm2
    return np.array(numerator / denominator)

def normal(x, mu, sigma):
    # Normal (Gaussian) probability density with mean mu and standard deviation sigma
    numerator = np.exp(-(np.array(x)-mu)**2 / (2*sigma**2))
    denominator = np.sqrt(2*np.pi * sigma**2)
    return np.array(numerator / denominator) 
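
As a quick sanity check on these helpers (a sketch, not part of the original code): the density returned by normal() should integrate to roughly 1, and the two responsibilities should sum to 1 at every point. The grid bounds and step below are arbitrary choices made for illustration; the parameter values are the ones from initial_params.

```python
import numpy as np

def normal(x, mu, sigma):
    # Normal probability density with mean mu and standard deviation sigma
    numerator = np.exp(-(np.array(x) - mu)**2 / (2 * sigma**2))
    denominator = np.sqrt(2 * np.pi * sigma**2)
    return np.array(numerator / denominator)

def gamma_1n_old(pi1_old, norm1, norm2):
    # Responsibility of component 1 for each point
    return np.array(pi1_old * norm1 / (pi1_old * norm1 + (1 - pi1_old) * norm2))

def gamma_2n_old(pi1_old, norm1, norm2):
    # Responsibility of component 2 for each point
    return np.array((1 - pi1_old) * norm2 / (pi1_old * norm1 + (1 - pi1_old) * norm2))

# Evenly spaced grid wide enough to cover both components (arbitrary choice)
x = np.linspace(-20, 40, 6001)
dx = x[1] - x[0]

norm1 = normal(x, 5, 2)
norm2 = normal(x, 15, 2)

# Riemann-sum approximation of the integral of the first density
area = np.sum(norm1) * dx

g1 = gamma_1n_old(0.3, norm1, norm2)
g2 = gamma_2n_old(0.3, norm1, norm2)

print("area under normal():", area)
print("responsibilities sum to 1:", np.allclose(g1 + g2, 1.0))
```

If both checks pass, the helpers themselves are probably fine, which points the search for the fault at the data or the update step instead.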

And I go through the loop here:

# now we can go through the EM loop

# start with a random set of parameters, the format of the list is: (pi_1, mu_1, mu_2, sigma)
rand = np.random.random(4)
params = [rand[0], rand[1]*10, rand[2]*10, rand[3]*10]

# initialize empty gamma lists
gamma1 = []
gamma2 = []

# make a copy of the synthetic data and use that to loop over
data = plot_synth_data.copy()

data_plot = [] # to get plots for specific iterations

for iteration in range(50):
    print(params)
    
    # get values for Normal_1 and Normal_2
    norm1 = normal(data, params[1], params[3])
    norm2 = normal(data, params[2], params[3])
#     print(norm1, norm2)

    # calculate the observation probability based on the old parameters
    gamma1_old = gamma_1n_old(params[0], norm1, norm2)
    gamma2_old = gamma_2n_old(params[0], norm1, norm2)
#     print(gamma1_old, gamma2_old)
    
    # need to append these to a new list so we can sum them across the whole time range
    gamma1.append(gamma1_old)
    gamma2.append(gamma2_old)
#     print(data)
#     print(np.sum(gamma1), np.sum(gamma1*data))
    
    # now to update the parameters for the next iteration
    params[0] = np.sum(gamma1_old) / np.sum(gamma1_old + gamma2_old)
    params[1] = np.sum(gamma1_old*data) / np.sum(gamma1_old)
    params[2] = np.sum(gamma2_old*data) / np.sum(gamma2_old)
    params[3] = np.sqrt(np.sum(gamma1_old * (data - params[1])**2) / np.sum(gamma1_old))
    
    # Just for convenience, we can plot every 7th iteration to visually check how it's changing
    if iteration % 7 == 0:
        plot = gauss_mix(data, params[0], params[1], params[2], params[3])
        data_plot.append(plot)  

The output of the print(params) statement is shown below. I've omitted some lines, as they don't change with successive iterations.

[0.1130842168240086, 3.401472765079545, 2.445209909135907, 2.3046528697572635]
[0.07054376684886957, 0.04341192273911035, 0.04067151364724695, 0.12585753071439582]
[0.07054303636195076, 0.04330910871714057, 0.040679319081395215, 0.12567545288855245]
[0.07054238762380395, 0.04321431848177363, 0.04068651514443456, 0.12550734898400692]
[0.07054180884360708, 0.043126645044752804, 0.04069317074867406, 0.125351664317294]
[0.07054129028636431, 0.04304531343415197, 0.040699344770810386, 0.12520706710362625]

I'm not sure what to make of the parameters here. For clarity, the list indices are [pi_1, mu_1, mu_2, sigma]. My initial guess is that I'm not using the data properly in the calculations, but I'm not sure how else I'd go about it.

Any advice or guidance is welcome. I'm not exactly looking for a full written-out solution, just advice on where my fault is. I'll keep my eye out for any questions to better clarify my issue.


Solution

  • I'm answering my own question here.

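    The problem with my code was the way I was generating the synthetic data. The original gauss_mix() evaluated the mixture curve at uniformly random positions, so the EM loop was fed function values rather than samples drawn from the mixture. The code below shows the correct method, which draws each sample from one of the two Gaussians.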

    # Create a synthetic Dataset
    def gauss_mix(pi1, mu1, mu2, sigma):
        if np.random.rand() < pi1:  # uniform draw in [0, 1): pick the first Gaussian with probability pi1
            return mu1 + np.random.randn() * sigma
        else:
            return mu2 + np.random.randn() * sigma
    
    # Now we define the initial parameters
    # The format of the list is: (pi_1, mu_1, mu_2, sigma)
    initial_params = [.3, 5, 15, 2]
    
    sample = 10000
    synth_data = []
    for dat in range(sample):
        synth_data.append(gauss_mix( initial_params[0], initial_params[1], initial_params[2], initial_params[3]))
    

    Plotting synth_data gives the following result:

    [plot of the sampled synth_data]
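
For completeness, here is a minimal sketch of how the corrected sampling and the EM loop might fit together. It is not the code from the course notes. The names sample_gauss_mix, normal_pdf and em_fit, the fixed seed, the iteration count and the starting guesses are all assumptions made for illustration. One deliberate difference from the loop in the question: the question's loop updates sigma from the first component only, while this sketch pools both components, which is the usual update when the two Gaussians share a single sigma.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed, chosen only for reproducibility

def sample_gauss_mix(n, pi1, mu1, mu2, sigma, rng):
    """Draw n samples from pi1*N(mu1, sigma^2) + (1 - pi1)*N(mu2, sigma^2)."""
    from_first = rng.random(n) < pi1          # uniform draws select the component
    means = np.where(from_first, mu1, mu2)
    return means + sigma * rng.standard_normal(n)

def normal_pdf(x, mu, sigma):
    """Normal density with mean mu and standard deviation sigma."""
    return np.exp(-(x - mu)**2 / (2 * sigma**2)) / np.sqrt(2 * np.pi * sigma**2)

def em_fit(data, pi1, mu1, mu2, sigma, n_iter=200):
    """EM for a two-component Gaussian mixture with a shared sigma."""
    for _ in range(n_iter):
        # E-step: responsibilities under the current parameters
        w1 = pi1 * normal_pdf(data, mu1, sigma)
        w2 = (1 - pi1) * normal_pdf(data, mu2, sigma)
        gamma1 = w1 / (w1 + w2)
        gamma2 = 1 - gamma1

        # M-step: re-estimate the parameters from the responsibilities
        pi1 = gamma1.mean()
        mu1 = np.sum(gamma1 * data) / np.sum(gamma1)
        mu2 = np.sum(gamma2 * data) / np.sum(gamma2)
        # Shared sigma: pool the weighted squared deviations of both components
        pooled = np.sum(gamma1 * (data - mu1)**2 + gamma2 * (data - mu2)**2)
        sigma = np.sqrt(pooled / data.size)
    return pi1, mu1, mu2, sigma

# True parameters from the question: (pi_1, mu_1, mu_2, sigma)
data = sample_gauss_mix(10000, 0.3, 5, 15, 2, rng)

# Starting guesses taken from the data itself (a hypothetical choice)
fit = em_fit(data, 0.5, data.min(), data.max(), data.std())
print(fit)
```

Starting mu1 at the smallest data value and mu2 at the largest keeps the component labels from swapping in this example. With purely random starting values, as in the question, the fit can converge with the two components exchanged, in which case pi_1 comes out near 0.7 rather than 0.3.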