How to perform conditional updates on shared variables in Theano?


Is there a way to update shared variables conditionally, depending on the result of the current function call? For example:

g_W = T.grad(cost=classifier.cost, wrt=classifier.W)
updates = [(classifier.W, classifier.W - learning_rate * g_W)]
model = theano.function([index], outputs=classifier.cost, updates=updates)

In this model, I need to update the weight parameter only if the cost is greater than 0. theano.function has a no_default_updates parameter, but it does not apply to the 'updates' parameter.


Solution

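  • You can use a symbolic conditional operation. Theano has two: switch and ifelse. switch (in theano.tensor) is applied element-wise, choosing between corresponding elements of two tensors, while ifelse (in theano.ifelse) takes a scalar condition and returns one of two whole expressions, like a conventional if/else. See the Theano documentation for details.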

    Here's an example that updates parameters only when the cost is positive.

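    import numpy
    import theano
    import theano.tensor as tt
    
    
    def compile_model(input_size, hidden_size, output_size, learning_rate):
        floatX = theano.config.floatX
        # Randomly initialised parameters of a one-hidden-layer network.
        w_h = theano.shared(numpy.random.standard_normal((input_size, hidden_size))
                            .astype(floatX), name='w_h')
        b_h = theano.shared(numpy.random.standard_normal((hidden_size,))
                            .astype(floatX), name='b_h')
        w_y = theano.shared(numpy.random.standard_normal((hidden_size, output_size))
                            .astype(floatX), name='w_y')
        b_y = theano.shared(numpy.random.standard_normal((output_size,))
                            .astype(floatX), name='b_y')
        x = tt.matrix('x')
        z = tt.vector('z')
        h = tt.tanh(theano.dot(x, w_h) + b_h)
        y = theano.dot(h, w_y) + b_y
        c = tt.sum(y - z)  # toy cost; it can be negative
        params = (w_h, b_h, w_y, b_y)
        grads = tt.grad(cost=c, wrt=list(params))
        # tt.switch yields the gradient step when c > 0 and 0 otherwise,
        # so each parameter is left unchanged whenever the cost is not positive.
        updates = [(p, p - tt.switch(tt.gt(c, 0), learning_rate * g, 0))
                   for p, g in zip(params, grads)]
        return theano.function([x, z], outputs=c, updates=updates)
    
    
    def main():
        f = compile_model(input_size=3, hidden_size=2, output_size=4, learning_rate=0.01)
        floatX = theano.config.floatX
        x = numpy.random.standard_normal((5, 3)).astype(floatX)
        z = numpy.random.standard_normal((4,)).astype(floatX)
        # Prints the cost; the parameters were updated only if it was > 0.
        print(f(x, z))
    
    
    if __name__ == '__main__':
        main()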
    

    Here either switch or ifelse could be used, but switch is generally preferable: ifelse does not appear to be as well supported throughout the Theano framework, and it requires a separate import (from theano.ifelse import ifelse), whereas switch is available directly from theano.tensor.
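As a rough analogy outside Theano, the difference between switch and ifelse resembles NumPy's numpy.where, which chooses element by element between two arrays (like switch), versus an ordinary Python conditional on a single scalar, which returns one whole branch (like ifelse). The helper names switch_like and ifelse_like below are hypothetical and are not Theano functions; this sketch only mimics the shape of the selection, not how Theano builds or evaluates either op.

```python
import numpy as np


def switch_like(cond, a, b):
    """Element-wise selection, analogous to theano.tensor.switch (hypothetical helper)."""
    return np.where(cond, a, b)


def ifelse_like(cond, a, b):
    """Whole-branch selection on a scalar condition, analogous to
    theano.ifelse.ifelse (hypothetical helper)."""
    if np.ndim(cond) != 0:
        raise ValueError("ifelse-style selection needs a scalar condition")
    return a if cond else b


a = np.array([1.0, 2.0, 3.0])
b = np.array([10.0, 20.0, 30.0])
mask = np.array([True, False, True])

print(switch_like(mask, a, b))         # element-wise mix: 1.0, 20.0, 3.0
print(ifelse_like(a.sum() > 0, a, b))  # scalar condition is true: all of a
```

With a scalar condition, as in the cost > 0 test above, both forms end up selecting the same thing, which is why either op works for this problem.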
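The effect of the switch-gated update in the example can also be checked numerically without Theano. The sketch below is a NumPy-only reimplementation, assuming a hypothetical single linear layer (no hidden layer) with the same toy cost c = sum(y - z). Because dc/dy is all ones for this cost, the gradients can be written by hand (x.T times a matrix of ones for the weights, the batch size for each bias), and numpy.where(c > 0, step, 0.0) stands in for tt.switch(tt.gt(c, 0), step, 0). The function name gated_sgd_step is made up for illustration.

```python
import numpy as np


def gated_sgd_step(w, b, x, z, learning_rate):
    """One gradient step on c = sum(x @ w + b - z), applied only when c > 0.

    Hypothetical NumPy stand-in for the Theano update
    p - switch(gt(c, 0), learning_rate * grad(c, p), 0).
    """
    y = x @ w + b            # shape (batch, output)
    c = np.sum(y - z)        # z broadcasts across the batch axis
    dc_dy = np.ones_like(y)  # d(sum)/dy is 1 for every element of y
    g_w = x.T @ dc_dy        # chain rule: shape (input, output)
    g_b = dc_dy.sum(axis=0)  # each bias gradient equals the batch size
    # np.where plays the role of tt.switch: a zero step unless c > 0.
    new_w = w - np.where(c > 0, learning_rate * g_w, 0.0)
    new_b = b - np.where(c > 0, learning_rate * g_b, 0.0)
    return new_w, new_b, c


rng = np.random.default_rng(0)
w = rng.standard_normal((3, 4))
b = rng.standard_normal(4)
x = rng.standard_normal((5, 3))

# Targets chosen far below / far above the outputs, so the cost is
# clearly positive in the first call and clearly negative in the second.
w1, b1, c1 = gated_sgd_step(w, b, x, np.full(4, -100.0), 0.01)  # parameters move
w2, b2, c2 = gated_sgd_step(w, b, x, np.full(4, 100.0), 0.01)   # parameters unchanged
print(c1, c2)
```

Only the hand-written gradients are specific to this toy model; the gating itself is the same expression used in the Theano version.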