The code snippet looks like this:
t = tensor.arange(1, K)
results, updates = theano.scan(fn=updatefunc, sequences=t, ...)
The scan iterates over t. However, when K <= 1, t is an empty range and theano.scan() crashes. Is there a way to handle this case?
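You can use theano.ifelse.ifelse to compute the scan result only when the sequence is non-empty. Unlike tt.switch, ifelse is lazy: only the branch selected by the condition is evaluated, so the scan never runs on an empty sequence. For example: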
import theano
import theano.tensor as tt
import theano.ifelse
def step(x_t, s_tm1):
    # Running sum: add the current element to the previous state
    return s_tm1 + x_t
def compile():
    K = tt.lscalar()
    t = tt.arange(1, K)  # empty when K <= 1
    zero = tt.constant(0, dtype='int64')
    outputs, _ = theano.scan(step, sequences=[t], outputs_info=[zero])
    # ifelse is lazy: outputs[-1] (and hence the scan) is only evaluated when K > 1
    output = theano.ifelse.ifelse(tt.gt(K, 1), outputs[-1], zero)
    return theano.function([K], outputs=[output])
def main():
    f = compile()
    print(f(3))
    print(f(2))
    print(f(1))
    print(f(0))
    print(f(-1))
main()
which prints (under Python 2):
[array(3L, dtype=int64)]
[array(1L, dtype=int64)]
[array(0L, dtype=int64)]
[array(0L, dtype=int64)]
[array(0L, dtype=int64)]
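To make the graph's semantics concrete, here is a plain-Python analogue (no Theano involved) of what the compiled function computes. It is only an illustration: the helper name last_running_sum is hypothetical, and the explicit loop stands in for theano.scan. Without the K > 1 guard, outputs[-1] on an empty result list would raise IndexError, which mirrors why the unguarded scan fails.

```python
def step(x_t, s_tm1):
    # Same recurrence as the Theano step: running sum
    return s_tm1 + x_t


def last_running_sum(K):
    """Plain-Python analogue of the compiled Theano function (hypothetical helper)."""
    zero = 0
    if K > 1:  # mirrors tt.gt(K, 1): only "scan" when range(1, K) is non-empty
        outputs = []
        s = zero  # mirrors outputs_info=[zero]
        for x_t in range(1, K):  # mirrors tt.arange(1, K)
            s = step(x_t, s)
            outputs.append(s)
        return outputs[-1]  # safe here; would raise IndexError on an empty list
    return zero  # the ifelse "else" branch


for K in (3, 2, 1, 0, -1):
    print(K, last_running_sum(K))
```

For K = 3, 2, 1, 0, -1 this returns 3, 1, 0, 0, 0, matching the Theano output above.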