Tags: python, neural-network, deep-learning, theano, theano.scan

Efficient kernel implementation in Theano


I have just implemented a Gaussian kernel in Theano, but when I test it as part of a neural network it is far too slow. The kernel subtractions do not seem to be parallelized: the whole training of the network runs on a single processing core. So how can I get Theano to properly parallelize the kernel operation?

import theano.tensor as T
import numpy
import theano

batch_s=5
dims=10
hidd_s=3
out_s=2

missing_param = None #"ignore"

rng = numpy.random.RandomState(1234)
input = T.matrix("input")
X = numpy.asarray(rng.uniform(low=-2.1, high=5.0, size=(batch_s, dims)))

def layer(x):

    W=theano.shared(
        value=numpy.asarray(
            rng.uniform(low=0.001, high=1.0, size=(dims, hidd_s)),
                dtype=theano.config.floatX),
        name='W', borrow=True)

    S=theano.shared(
        value=numpy.asarray(
            rng.uniform(low=10.0, high=100.0, size=(hidd_s, )),
                dtype=theano.config.floatX),
        name='S', borrow=True)

    dot_H = theano.shared(
        value=numpy.zeros((batch_s, hidd_s), 
            dtype=theano.config.floatX), 
        name='dot_H', borrow=True)
    # This is the kernel operation. I have tested it with a single scan as
    # well as with two nested scans, but the operations are not split across
    # cores the way the usual dot product T.dot() is.
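    # Note: with Python operator precedence, "/ 2 * S[j] ** 2" multiplies the
    # exponential by S[j]**2 / 2; a textbook Gaussian kernel would divide the
    # squared norm by (2 * S[j]**2) inside the exponent instead. The solution
    # below keeps the expression exactly as written here.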
    for i in range(batch_s):
        for j in range(hidd_s):
            dot_H = T.set_subtensor(dot_H[i,j], 
                     T.exp(-(W.T[j] - x[i]).norm(2) ** 2) / 2 * S[j] ** 2)
    return dot_H

layer_out = theano.function(
                            inputs=[input], 
                            outputs=layer(input), 
                            on_unused_input=missing_param
                            )
print(layer_out(X))

Thank you very much.


Solution

  • Removing the Python loops allows Theano to optimize and parallelize the
    computation.

    First, you can eliminate the inner loop by assigning a whole row of
    dot_H at once, letting broadcasting handle the subtraction across all
    hidden units:

    for i in range(batch_s):
        dot_H = T.set_subtensor(dot_H[i],
            T.exp(-(W.T - x[i]).norm(2, axis=1) ** 2) / 2 * S ** 2)
    

    Then you can replace the outer loop with theano.map:

    import theano.tensor as T
    import numpy
    import theano
    import timeit
    
    start = timeit.default_timer()
    batch_s=5
    dims=10
    hidd_s=3
    out_s=2
    
    missing_param = None #"ignore"
    
    rng = numpy.random.RandomState(1234)
    input = T.matrix("input")
    X = numpy.asarray(rng.uniform(low=-2.1, high=5.0, size=(batch_s, dims)))
    
    
    
    W=theano.shared(
            value=numpy.asarray(
                rng.uniform(low=0.001, high=1.0, size=(dims, hidd_s)),
                    dtype=theano.config.floatX),
            name='W', borrow=True)
    
    S=theano.shared(
            value=numpy.asarray(
                rng.uniform(low=10.0, high=100.0, size=(hidd_s, )),
                    dtype=theano.config.floatX),
            name='S', borrow=True)
    
    
    f_func, f_updates = theano.map(
        lambda v: T.exp(-(W.T - v).norm(2, axis=1) ** 2) / 2 * S ** 2,
        sequences=input)
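    # theano.map builds a symbolic loop over the rows of `input` and returns
    # both the mapped output and an updates object; the updates must be
    # passed on to theano.function below.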
    
    
    layer_out = theano.function([input],
                                f_func,
                                updates=f_updates,
                                on_unused_input=missing_param,
                                allow_input_downcast=True)
    
    
    print(layer_out(X.astype('float32')))
    
    stop = timeit.default_timer()
    
    print "running time: " + str(stop - start) 
    

    The output for the original code is:

    [[  1.83701953e-25   1.78982216e-26   9.22911484e-27]
     [  1.60078639e-17   9.21553384e-17   7.62476155e-14]
     [  8.13404350e-17   1.88481821e-17   2.44677516e-15]
     [  3.16093011e-29   1.49698827e-27   2.42876079e-27]
     [  9.57103818e-09   3.46683533e-12   6.66103154e-12]]
    running time: 1.30477905273
    

    With the new one:

    [[  1.83701953e-25   1.78982216e-26   9.22911484e-27]
     [  1.60078639e-17   9.21553384e-17   7.62476155e-14]
     [  8.13404350e-17   1.88481821e-17   2.44677516e-15]
     [  3.16093011e-29   1.49698827e-27   2.42876079e-27]
     [  9.57103818e-09   3.46683533e-12   6.66103154e-12]]
    running time: 0.589275121689
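
    Note that the reported running times include graph construction and
    compilation, not just the kernel evaluation, so at this tiny batch size
    they mostly reflect compilation cost.

    Going one step further, the map can be removed entirely with
    broadcasting, so the whole kernel becomes a single elemwise graph with
    no scan overhead at all. Here is a minimal sketch under the same shapes
    (the dimshuffle broadcasting and the names diff and sq_dist are my own,
    not from the original code); it keeps the same expression, including
    the original "/ 2 * S ** 2" precedence:

    import numpy
    import theano
    import theano.tensor as T

    batch_s, dims, hidd_s = 5, 10, 3
    rng = numpy.random.RandomState(1234)

    input = T.matrix("input")
    W = theano.shared(
        rng.uniform(low=0.001, high=1.0,
                    size=(dims, hidd_s)).astype(theano.config.floatX),
        name='W', borrow=True)
    S = theano.shared(
        rng.uniform(low=10.0, high=100.0,
                    size=(hidd_s,)).astype(theano.config.floatX),
        name='S', borrow=True)

    # Broadcast (batch_s, 1, dims) against (1, hidd_s, dims) so that all
    # pairwise differences come out in one tensor of shape
    # (batch_s, hidd_s, dims).
    diff = input.dimshuffle(0, 'x', 1) - W.T.dimshuffle('x', 0, 1)
    sq_dist = T.sum(diff ** 2, axis=2)  # squared L2 norms, (batch_s, hidd_s)

    out = T.exp(-sq_dist) / 2 * S ** 2
    layer_out_vec = theano.function([input], out,
                                    allow_input_downcast=True)

    With no scan step left in the graph, Theano is free to fuse and
    parallelize the elemwise operations; this should be at least as fast as
    the map version, although I have not benchmarked it here.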