tensorflow skflow

How do I get the variables used in the model function passed to TensorFlowEstimator?


How do I get the variables (e.g., the embedding table, RNN variables, etc.) after fitting a model with TensorFlowEstimator, such as in this skflow example? Calling tf.all_variables() returns an empty list.


Solution

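  • You can call get_variable_names() on the estimator.

    The estimator builds the model in its own graph rather than in the default graph, which is why a bare tf.all_variables() finds nothing. You need to make the estimator's graph the default and then call all_variables(). For example, inside the estimator (where variables is TensorFlow's variables module):

        with self._graph.as_default():
            print([v.name for v in variables.all_variables()])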
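To illustrate why the lookup has to happen inside the estimator's graph, here is a small TensorFlow-free sketch of graph-scoped variable collections. Everything in it (ToyGraph, ToyEstimator, as_default, create_variable, and this all_variables) is a hypothetical stand-in written for this example, not the skflow or TensorFlow API; only the scoping behaviour is meant to carry over.

```python
# Toy model of graph-scoped variable collections.
# None of these names are TensorFlow or skflow APIs; they are illustrative stand-ins.
from contextlib import contextmanager


class ToyGraph:
    """Stand-in for a graph that records the variables created inside it."""

    def __init__(self):
        self.variables = []


# The graph that is current when no other graph has been entered.
_graph_stack = [ToyGraph()]


@contextmanager
def as_default(graph):
    """Make `graph` the current graph for the duration of the with-block."""
    _graph_stack.append(graph)
    try:
        yield graph
    finally:
        _graph_stack.pop()


def create_variable(name):
    """Register a variable name on whichever graph is current."""
    _graph_stack[-1].variables.append(name)
    return name


def all_variables():
    """Return the variable names registered on the current graph only."""
    return list(_graph_stack[-1].variables)


class ToyEstimator:
    """Stand-in for an estimator that builds its model in a private graph."""

    def __init__(self):
        self._graph = ToyGraph()

    def fit(self):
        # The model's variables are created while the private graph is current.
        with as_default(self._graph):
            create_variable("embeddings")
            create_variable("rnn/weights")

    def get_variable_names(self):
        # Re-enter the private graph before listing its variables.
        with as_default(self._graph):
            return all_variables()


estimator = ToyEstimator()
estimator.fit()
print(all_variables())                 # queried against the global default graph
print(estimator.get_variable_names())  # queried against the estimator's own graph
```

In this sketch the first lookup sees nothing because the variables were never registered on the global default graph, which mirrors the empty list described in the question.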