
How does theano's scan function work?


Look at this code:

import theano
import theano.tensor as T
import numpy as np

x = T.dvector('x')
y = T.dvector('y')

def fun(x, a):
    return x + a

results, updates = theano.scan(fn=fun, sequences=dict(input=x), outputs_info=dict(initial=y, taps=[-3]))

h = [10., 20, 30, 40, 50, 60, 70]
f = theano.function([x, y], results)

print(f([1], h))

I have changed the taps in outputs_info to -2, -3, and so on, but the result is always the same, [11.0]. I can't understand why. Can somebody explain it?

Another question.

import theano
import theano.tensor as T
import numpy as np

x = T.dvector('x')
y = T.dvector('y')

def fun(x, a, b):
    return x + a + b

results, updates = theano.scan(fn=fun, sequences=dict(input=x), outputs_info=dict(initial=y, taps=[-5, -3]))

h = [10., 20, 30, 40, 50, 60, 70]
f = theano.function([x, y], results)

print(f([1, 2, 3, 4], h))

The output is [41, 62, 83, 85]. Where does 85 come from?


Solution

  • Consider this variation on your code:

    import theano
    import theano.tensor as T

    x = T.dvector('x')
    y = T.dvector('y')

    def fun(x, a, b):
        # a receives the tap -5 value, b the tap -3 value; only b is used here
        return x + b

    results, updates = theano.scan(
        fn=fun,
        sequences=dict(input=x),
        outputs_info=dict(initial=y, taps=[-5, -3])
    )

    h = [10., 20, 30, 40, 50, 60, 70]
    f = theano.function([x, y], results)

    print(f([1], h))
    

    Your result will be 31, that is, 1 + 30.

    • Change taps to [-5, -2] and your result changes to 41.
    • Change taps to [-4, -3] and your result changes to 21.

    This shows how the taps map onto h:

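    1. The most negative tap (the one with the largest magnitude, such as -5) reads h[0] at the first step.
    2. Every other tap is offset from that one: if -m is the most negative tap, then at the first step tap -k reads h[m - k].

    So when taps is [-5, -2], fun receives a = h[0] = 10 and b = h[3] = 40. The same rule answers your first question: with a single tap such as [-3] or [-2], that tap is always the most negative one, so it always reads h[0] = 10 and the result is always 1 + 10 = 11.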
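If Theano isn't at hand, the rule can be checked with a small pure-Python model. This is a sketch of the behaviour described above, not Theano's implementation: the name `scan_taps` is hypothetical, and the buffer layout (the first `max_lag` elements of `initial` standing for times `-max_lag` through `-1`) is an assumption inferred from the results reported in this thread.

```python
# Pure-Python model of scan with one output and negative taps.
# ASSUMPTION: the layout is inferred from the experiments above, not taken
# from Theano's source: initial[0] is the output at time -max_lag,
# initial[1] at time -max_lag + 1, ..., and later elements are never read.


def scan_taps(fn, xs, initial, taps):
    max_lag = -min(taps)               # e.g. taps [-5, -3] -> 5
    buf = list(initial[:max_lag])      # outputs at times -max_lag .. -1
    outputs = []
    for t, x in enumerate(xs):
        # tap -k at step t reads the output from time t - k,
        # which is stored at buf index (t - k) + max_lag
        args = [buf[t + tap + max_lag] for tap in taps]
        out = fn(x, *args)             # sequence element first, then taps in listed order
        buf.append(out)                # becomes the output at time t
        outputs.append(out)
    return outputs


def add_a(x, a):                       # fun from the first question
    return x + a


def add_b(x, a, b):                    # fun from the variation above
    return x + b


h = [10., 20, 30, 40, 50, 60, 70]

# A single tap is always the most negative one, so it always reads h[0].
for k in range(1, 6):
    print([-k], scan_taps(add_a, [1], h, [-k]))      # [11.0] every time

print(scan_taps(add_b, [1], h, [-5, -3]))            # [31.0]
print(scan_taps(add_b, [1], h, [-5, -2]))            # [41.0]
print(scan_taps(add_b, [1], h, [-4, -3]))            # [21.0]
```

The model reproduces all four of the numbers above, which is evidence for the assumed layout rather than proof of how Theano stores it internally.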

    Update for new question

    A tap of -k means that the function at time t receives the output it produced at time t - k.
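The same idea can be written as an index calculation. The hypothetical helper `tap_source` below names the value a tap reads at step `t`, assuming, as in the rules above, that h[0] stands for the earliest time the most negative tap can reach:

```python
def tap_source(t, tap, max_lag):
    # ASSUMPTION: h[0] is the output at time -max_lag, so a negative time s
    # lives at h[s + max_lag]; times s >= 0 are the function's own outputs.
    s = t + tap                          # a tap of -k looks back k steps
    if s < 0:
        return f"h[{s + max_lag}]"
    return f"output[{s}]"


taps = [-5, -3]
max_lag = -min(taps)                     # 5
for t in range(4):
    reads = ", ".join(f"tap {k} -> {tap_source(t, k, max_lag)}" for k in taps)
    print(f"t={t}: {reads}")
# t=0: tap -5 -> h[0], tap -3 -> h[2]
# t=1: tap -5 -> h[1], tap -3 -> h[3]
# t=2: tap -5 -> h[2], tap -3 -> h[4]
# t=3: tap -5 -> h[3], tap -3 -> output[0]
```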

    For instance, the Fibonacci sequence is defined by the function

    F(t) = F(t-2) + F(t-1)

    Here's how you'd implement the Fibonacci sequence with theano.scan:

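    import numpy as np
    import theano
    import theano.tensor as T

    x = T.ivector('x')  # only drives the number of steps
    y = T.ivector('y')  # the two seed values

    def fibonacci(x, a, b):
        # x: current sequence element (unused); a: output at t-2; b: output at t-1
        return a + b

    results, _ = theano.scan(
        fn=fibonacci,
        sequences=dict(input=x),
        outputs_info=dict(initial=y, taps=[-2, -1])
    )

    h = np.array([1, 1], dtype='int32')
    f = theano.function([x, y], results)

    print(np.append(h, f(np.arange(10, dtype='int32'), h)))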
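To check the expected sequence without Theano, here is a plain-Python loop that mirrors the scan above. `fibonacci_scan` is a hypothetical name, and the sketch assumes exactly two seed values, one per tap:

```python
def fibonacci_scan(n_steps, seeds):
    # Mirrors taps [-2, -1]: each step reads the last two values in history.
    history = list(seeds)                 # outputs at times -2 and -1
    outputs = []
    for _ in range(n_steps):              # scan runs once per element of x
        a, b = history[-2], history[-1]   # tap -2 and tap -1
        out = a + b
        history.append(out)               # later steps read this output
        outputs.append(out)
    return outputs


h = [1, 1]
print(h + fibonacci_scan(10, h))
# [1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144]
```

From the third step on, both taps read values the loop itself produced rather than the seeds.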
    

    This raises a question: if the function depends on prior outputs, what does it use as the prior output for the first iterations, before any output exists?

    The answer is the initial value given in outputs_info, h in your case. Your h is longer than it needs to be: it only needs 5 elements, because the most negative tap is -5. Scan treats h[0] through h[4] as the outputs at times -5 through -1 and never reads h[5] or h[6]. Once a tap points at time 0 or later, it reads the function's actual output instead of h.

    Here's a simplified trace of what's happening in your code:

    1. output[0] = x[0] + h[0] + h[2] = 1 + 10 + 30 = 41
    2. output[1] = x[1] + h[1] + h[3] = 2 + 20 + 40 = 62
    3. output[2] = x[2] + h[2] + h[4] = 3 + 30 + 50 = 83
    4. output[3] = x[3] + h[3] + output[0] = 4 + 40 + 41 = 85

    At time t = 3, tap -3 points at time 0, and the function has already produced an output for that step: 41. Because scan feeds prior outputs back into the function, it uses that 41 instead of reading any further element of h.
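The trace can be reproduced in plain Python. `question2_scan` is a hypothetical mirror of your `fun` with taps [-5, -3]; it assumes, as described above, that h[0] through h[4] stand for the outputs at times -5 through -1:

```python
def question2_scan(xs, h):
    # ASSUMPTION: h[0]..h[4] are the outputs at times -5..-1, so time s
    # is stored at history index s + 5; h[5] and h[6] are never read.
    history = list(h[:5])
    outputs = []
    for t, x in enumerate(xs):
        a = history[t]          # tap -5: time t - 5 -> index t
        b = history[t + 2]      # tap -3: time t - 3 -> index t + 2
        out = x + a + b         # fun(x, a, b) = x + a + b
        history.append(out)     # stored at index t + 5, i.e. time t
        outputs.append(out)
    return outputs


h = [10., 20, 30, 40, 50, 60, 70]
print(question2_scan([1, 2, 3, 4], h))   # [41.0, 62.0, 83.0, 85.0]

# If the last step read h[5] instead of output[0], it would give
# 4 + h[3] + h[5] = 104, not the 85 observed.
print(4 + h[3] + h[5])                   # 104.0
```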