
How to handle empty sequence with theano.scan()?


My code looks like this:

    t = tensor.arange(1, K)
    results, updates = theano.scan(fn=updatefunc, sequences=t, ...)

The scan iterates over t. However, when K <= 1, t is an empty range and theano.scan() crashes. Is there any way to handle this case?


Solution

  • You can use theano.ifelse.ifelse so that the scan result is only computed when the sequence has at least one element; otherwise a default value is returned. For example:

    import theano
    import theano.tensor as tt
    import theano.ifelse
    
    
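    def step(x_t, s_tm1):
        # Running sum: add the current sequence element to the previous state.
        return s_tm1 + x_t
    
    
    def compile():
        K = tt.lscalar()
        t = tt.arange(1, K)  # empty when K <= 1
        zero = tt.constant(0, dtype='int64')
        outputs, _ = theano.scan(step, sequences=[t], outputs_info=[zero])
        # Use the last scan output only when t is non-empty; otherwise return zero.
        output = theano.ifelse.ifelse(tt.gt(K, 1), outputs[-1], zero)
        return theano.function([K], outputs=[output])
    
    
    def main():
        f = compile()
        # print() with a single argument behaves the same on Python 2 and 3.
        print(f(3))
        print(f(2))
        print(f(1))
        print(f(0))
        print(f(-1))
    
    
    main()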
    

    Under Python 2, this prints:

    [array(3L, dtype=int64)]
    [array(1L, dtype=int64)]
    [array(0L, dtype=int64)]
    [array(0L, dtype=int64)]
    [array(0L, dtype=int64)]
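
As a sanity check on the values above, the same guarded running sum can be sketched in plain Python, without Theano. This is only a reference for the expected results, not a replacement for the scan; the name `guarded_sum` is made up for illustration, and it assumes the same step function (a running sum) and the same fallback of zero when the range is empty.

```python
def guarded_sum(K):
    """Plain-Python analogue of the guarded scan (hypothetical helper)."""
    t = list(range(1, K))  # empty when K <= 1, like tt.arange(1, K)
    if not t:
        # Mirrors the "else" branch of ifelse: skip the loop, return the default.
        return 0
    state = 0  # initial state, like outputs_info=[zero]
    for x_t in t:
        state = state + x_t  # same update as step(x_t, s_tm1)
    return state  # like outputs[-1]


for K in (3, 2, 1, 0, -1):
    print(K, guarded_sum(K))
```

The values for K = 3, 2, 1, 0, -1 should match the Theano output: 3, 1, 0, 0, 0.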