Tags: python, algorithm, neural-network, backpropagation, gradient-descent

Backpropagation with Momentum


I'm following this tutorial for implementing the backpropagation algorithm. However, I am stuck on adding momentum to it.

Without momentum, this is the code for the weight update method:

def update_weights(network, row, l_rate):
    for i in range(len(network)):
        inputs = row[:-1]
        if i != 0:
            inputs = [neuron['output'] for neuron in network[i - 1]]
        for neuron in network[i]:
            for j in range(len(inputs)):
                neuron['weights'][j] += l_rate * neuron['delta'] * inputs[j]
            neuron['weights'][-1] += l_rate * neuron['delta']

And below is my implementation:

def updateWeights(network, row, l_rate, momentum=0.5):
    for i in range(len(network)):
        inputs = row[:-1]
        if i != 0:
            inputs = [neuron['output'] for neuron in network[i-1]]
        for neuron in network[i]:
            for j in range(len(inputs)):
                previous_weight = neuron['weights'][j]
                neuron['weights'][j] += l_rate * neuron['delta'] * inputs[j] + momentum * previous_weight
            previous_weight = neuron['weights'][-1]
            neuron['weights'][-1] += l_rate * neuron['delta'] + momentum * previous_weight

This gives me a math overflow error, because the weights grow exponentially over multiple epochs. I believe my previous_weight logic for the update is wrong.
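The growth can be reproduced outside the network. The following minimal sketch applies the same update rule to a single scalar weight, assuming a constant gradient step; the function name `buggy_updates` and the value of `step` are hypothetical, chosen only for illustration:

```python
def buggy_updates(w, step, momentum, n):
    """Apply the question's update rule n times to one scalar weight.

    `step` stands in for l_rate * neuron['delta'] * inputs[j] and is
    held constant here as a simplifying assumption.
    """
    history = [w]
    for _ in range(n):
        previous_weight = w  # the weight itself, not the previous update
        w += step + momentum * previous_weight
        history.append(w)
    return history


history = buggy_updates(w=0.1, step=0.01, momentum=0.5, n=50)
print(history[-1] / history[-2])  # ratio between consecutive weights
```

Each iteration multiplies the weight by roughly 1 + momentum (1.5 here), so after 50 updates it is already in the tens of millions; with momentum=0 the same loop only grows linearly.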


Solution

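  • I'll give you a hint. You're multiplying momentum by previous_weight, which is the current value of the weight itself (a parameter of the network), not the previous update. Each step therefore computes w += l_rate * delta * input + momentum * w, which scales the weight by (1 + momentum) on every update, so it grows exponentially.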

    What you should do instead is remember the whole update applied on the previous backpropagation step (the velocity) and add momentum times it to the new gradient step, l_rate * neuron['delta'] * inputs[j]. It might look something like this:

    velocity[j] = l_rate * neuron['delta'] * inputs[j] + momentum * velocity[j]
    neuron['weights'][j] += velocity[j]
    

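    ... where velocity holds one entry per weight of each neuron (the same shape as the weights, not the same length as network), is defined with a wider scope than updateWeights so that it persists between calls, and is initialized with zeros. See this post for details.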
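Putting the pieces together, here is a minimal sketch of the whole method. It assumes the velocities are stored in each neuron dict under a hypothetical 'velocity' key, so they have the same shape as neuron['weights'] and persist across calls; the bias weight gets the same momentum treatment with a constant input of 1:

```python
def updateWeights(network, row, l_rate, momentum=0.5):
    """Weight update with classical momentum.

    Velocities live in each neuron under the hypothetical key 'velocity'
    and are created lazily as zeros, one per weight (including the bias).
    """
    for i in range(len(network)):
        inputs = row[:-1]
        if i != 0:
            inputs = [neuron['output'] for neuron in network[i - 1]]
        for neuron in network[i]:
            if 'velocity' not in neuron:
                neuron['velocity'] = [0.0] * len(neuron['weights'])
            velocity = neuron['velocity']
            for j in range(len(inputs)):
                # New velocity = current gradient step + decayed previous velocity.
                velocity[j] = l_rate * neuron['delta'] * inputs[j] + momentum * velocity[j]
                neuron['weights'][j] += velocity[j]
            # Bias weight: its input is the constant 1.
            velocity[-1] = l_rate * neuron['delta'] + momentum * velocity[-1]
            neuron['weights'][-1] += velocity[-1]
```

With a constant gradient step g, the velocity approaches g / (1 - momentum), i.e. twice the plain step for momentum=0.5, instead of growing without bound.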