There is no need to provide an input sequence to `scan`. You can instead specify the number of iterations via `scan`'s `n_steps` parameter. Optionally, you can also specify a condition under which the scan should stop early via `theano.scan_module.until`.
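To make the "no input sequence" case concrete, here is a plain-Python model of what `scan` computes when it is given only `outputs_info`, `non_sequences` and `n_steps`. This is not Theano code: `scan_model` and `add_step` are hypothetical helpers, and the model assumes a single recurrent output. Note that the initial state is not part of the result; only the values returned by each step are collected.

```python
def scan_model(step_fn, init, n_steps, non_sequences=()):
    """Plain-Python model (hypothetical helper) of theano.scan called with no
    `sequences`: one recurrent output seeded from `init`, run `n_steps` times."""
    outputs = []
    x_tm1 = init
    for _ in range(n_steps):
        x_t = step_fn(x_tm1, *non_sequences)
        outputs.append(x_t)  # only step results are collected, never `init`
        x_tm1 = x_t
    return outputs


def add_step(x_tm1, step):
    # Same recurrence as scan_range_step in the implementation that follows.
    return x_tm1 + step


print(scan_model(add_step, 0, 4, non_sequences=(3,)))  # [3, 6, 9, 12]
```

Because the initial value never shows up in the output, a scan-based `range` has to prepend `start` itself.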
So Python's `range` function can be emulated using Theano's `scan` without specifying an input sequence, by working out how many iterations are required to construct the requested sequence.
Here's an implementation of the `range` function based on Theano's `scan`. The only complicated part is working out how many steps are required.
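The step count follows from the length of the requested range: ceil(|stop - start| / |step|) elements when the step points from start towards stop, and zero otherwise. Since `start` is supplied as the initial state rather than produced by a step, scan needs one step fewer than that. A plain-Python sketch of the arithmetic (the helper names are made up for illustration), checked against Python's own `range`:

```python
def range_length(start, stop, step):
    """Number of elements in range(start, stop, step) (hypothetical helper)."""
    if step == 0:
        raise ValueError('step must not be zero')
    # The step must point from start towards stop, otherwise the range is empty.
    if (step > 0 and start >= stop) or (step < 0 and start <= stop):
        return 0
    # Integer ceiling division: ceil(a / b) == -(-a // b) for b > 0.
    return -(-abs(stop - start) // abs(step))


def scan_n_steps(start, stop, step):
    """Steps scan must run after the initial element (non-empty ranges only)."""
    return range_length(start, stop, step) - 1


values = [-10, -5, -1, 0, 1, 5, 10]
for start in values:
    for stop in values:
        for step in [-3, -2, -1, 1, 2, 3]:
            assert range_length(start, stop, step) == len(range(start, stop, step))

print(range_length(-10, 10, 3), scan_n_steps(-10, 10, 3))  # 7 6
print(scan_n_steps(0, 1, 1))  # 0: a one-element range needs no scan steps
```

Integer ceiling division avoids the round trip through floating point that the Theano version takes with `tt.ceil`; for the small values used in the test either approach gives the same counts.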
```python
import numpy
import theano
import theano.tensor as tt
import theano.ifelse


def scan_range_step(x_tm1, step):
    # One iteration: the next element is the previous one plus the step.
    return x_tm1 + step


def compile_theano_range():
    symbolic_start = tt.lscalar()
    symbolic_stop = tt.lscalar()
    symbolic_step = tt.lscalar()

    # The range has ceil(|stop - start| / |step|) elements. The first one
    # (start) is the initial state, so scan only has to produce the rest.
    n_steps = tt.cast(
        tt.ceil(tt.abs_(symbolic_stop - symbolic_start) / tt.cast(tt.abs_(symbolic_step), theano.config.floatX)),
        'int64') - 1
    outputs, _ = theano.scan(scan_range_step, outputs_info=[symbolic_start], n_steps=n_steps,
                             non_sequences=[symbolic_step], strict=True)
    # For a one-element range (n_steps == 0), skip scan and append nothing:
    # an empty vector (rather than [start]) keeps start from appearing twice
    # after the concatenation below.
    outputs = theano.ifelse.ifelse(tt.eq(n_steps, 0), tt.zeros((0,), dtype='int64'), outputs)
    f = theano.function([symbolic_start, symbolic_stop, symbolic_step],
                        outputs=tt.concatenate([[symbolic_start], outputs]))

    def theano_range(start, stop=None, step=1):
        assert isinstance(start, int)
        assert isinstance(step, int)
        if step == 0:
            raise ValueError('step must not be zero')
        if stop is None:
            # theano_range(n) behaves like range(n).
            stop = start
            start = 0
        else:
            assert isinstance(stop, int)
        # Empty ranges are handled in Python; the compiled function
        # assumes at least one element.
        if start == stop:
            return []
        if stop < start and step > 0:
            return []
        if stop > start and step < 0:
            return []
        return f(start, stop, step)

    return theano_range


def main():
    theano_range = compile_theano_range()
    python_range = range
    for start in [-10, -5, -1, 0, 1, 5, 10]:
        for stop in [-10, -5, -1, 0, 1, 5, 10]:
            for step in [-3, -2, -1, 1, 2, 3]:
                a = theano_range(start, stop, step)
                b = list(python_range(start, stop, step))
                # array_equal also compares lengths; numpy.equal would broadcast
                # a one-element array against a longer one and hide a mismatch.
                assert numpy.array_equal(a, b), (start, stop, step, a, b)


main()
```
Clearly this is a daft thing to do for real, since Theano already provides a symbolic version of Python's `range` function, namely `theano.tensor.arange`. The built-in implementation is also far more efficient than our `scan` version because it doesn't use `scan` at all; it uses a custom operation instead.
As a rule of thumb: you have to set a maximum number of iteration steps, either via the length of the input sequence (e.g. a `range`) or via the `n_steps` argument. You can set it to a very large number and then use `theano.scan_module.until` to stop the iteration early once your stop condition is met.
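A typical use of this pattern is an iterative method that should run until it converges. In Theano the step function would return its new value together with `theano.scan_module.until(converged)`, with `n_steps` set to a generous cap. The sketch below only models that control flow in plain Python, using Newton's iteration for a square root; `newton_sqrt`, `max_steps` and `tol` are hypothetical names, and it assumes (as I understand Theano's behaviour) that the output of the step on which the condition fires is still kept.

```python
import math


def newton_sqrt(a, max_steps=10**6, tol=1e-12):
    """Model of a scan with n_steps=max_steps whose step function returns
    (x_t, until(|x_t - x_tm1| < tol)); returns every step's output."""
    if a < 0:
        raise ValueError('a must be non-negative')
    outputs = []
    x_tm1 = max(a, 1.0)  # starting guess at or above sqrt(a)
    for _ in range(max_steps):  # n_steps: the hard upper bound
        x_t = 0.5 * (x_tm1 + a / x_tm1)  # Newton update for x**2 - a
        outputs.append(x_t)  # the final step's output is kept
        if abs(x_t - x_tm1) < tol:  # until(...): stop early
            break
        x_tm1 = x_t
    return outputs


steps = newton_sqrt(2.0)
print(len(steps), steps[-1], math.sqrt(2.0))
```

The cap only matters if convergence never happens; in the normal case the `until`-style check ends the loop after a handful of steps, so an oversized `max_steps` costs nothing.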