python theano symbolic-math

How do we get the list of variables that a theano expression depends on?


In sympy, I would do something like:

In [6]: import sympy as sp

In [7]: sp.var('x, y')
Out[7]: (x, y)

In [8]: X = x + y

In [9]: X.free_symbols
Out[9]: {y, x}

to get the variables that X depends on. This is very convenient, for example when we want to lambdify the expression afterwards:

f = sp.lambdify(tuple(X.free_symbols), X)

I would like to do something similar with theano:

import theano
import theano.tensor as T
x, y = T.dvectors('x', 'y')
X = x + y
f = theano.function([x, y], X)

But instead of providing [x, y] by hand, I would like to directly access the list of variables needed to create the theano.function.

Is this possible? I could not find it in the Theano documentation, so any help or a link would be appreciated :)


Solution

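  • Some Theano functions are not well documented because they are mostly meant for internal use. One of them, theano.gof.graph.inputs, returns the input variables that an expression depends on: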

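    import theano
    import theano.tensor as T

    x, y = T.vectors('xy')
    z = x + y
    # Walk the graph back from z and collect the variables it depends on.
    inputs = theano.gof.graph.inputs([z])
    print(inputs)
    # The list can be passed straight to theano.function.
    f = theano.function(inputs, z)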

    Output:

    [x, y]
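
To give a rough idea of what such a lookup does, here is a minimal pure-Python sketch of collecting the leaf variables of an expression graph. It is not Theano's implementation: the `Var` and `Op` classes and the `graph_inputs` helper are hypothetical names made up for this illustration, and the only assumption is that every node knows the nodes it was computed from.

```python
class Var:
    """Hypothetical leaf node: a named symbolic input with no parents."""

    def __init__(self, name):
        self.name = name
        self.parents = ()  # a leaf was not computed from anything

    def __add__(self, other):
        return Op("add", self, other)

    def __repr__(self):
        return self.name


class Op(Var):
    """Hypothetical interior node: the result of an operation on its parents."""

    def __init__(self, kind, *parents):
        self.name = kind
        self.parents = parents


def graph_inputs(outputs):
    """Return the leaves reachable from `outputs`, depth-first, without duplicates."""
    seen, leaves = set(), []
    stack = list(reversed(outputs))
    while stack:
        node = stack.pop()
        if id(node) in seen:
            continue  # already visited through another path
        seen.add(id(node))
        if node.parents:
            # Push parents so that the leftmost one is visited first.
            stack.extend(reversed(node.parents))
        else:
            leaves.append(node)
    return leaves


x, y = Var("x"), Var("y")
z = (x + y) + x  # x is used twice but should be reported once
print(graph_inputs([z]))  # [x, y]
```

In a real graph the leaves may also include things that are not meant to be function arguments (constants, for example), so it can be worth inspecting the returned list before passing it on.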