Tags: python, tensorflow, skflow

Use batch_size in model_fn in skflow


I need to create a random variable inside my model_fn(), having shape [batch_size, 20].

I do not want to pass batch_size as an argument, because then I cannot use a different batch size for prediction.

Removing the parts which do not concern this question, my model_fn() is:

def model(inp, out):
    eps = tf.random_normal([batch_size, 20], 0, 1, name="eps"))) # batch_size is the 
    # value I do not want to hardcode

    # dummy example
    predictions = tf.add(inp, eps)
    return predictions, 1

If I replace [batch_size, 20] with inp.get_shape(), I get

ValueError: Cannot convert a partially known TensorShape to a Tensor: (?, 20)

when running myclf.setup_training().

If I try

def model(inp, out):
    batch_size = tf.placeholder("float", [])
    eps = tf.random_normal([batch_size.eval(), 20], 0, 1, name="eps")))

    # dummy example
    predictions = tf.add(inp, eps)
    return predictions, 1

I get ValueError: Cannot evaluate tensor using eval(): No default session is registered. Use with sess.as_default() or pass an explicit session to eval(session=sess) (understandably, because I have not provided a feed_dict).

How can I access the value of batch_size inside model_fn(), while remaining able to change it during prediction?


Solution

  • I wasn't aware of the difference between Tensor.get_shape() and tf.shape(Tensor): get_shape() returns the statically known (possibly partial) shape at graph-construction time, while tf.shape() returns a tensor holding the actual shape at run time. The latter works:

    eps = tf.random_normal(tf.shape(inp), 0, 1, name="eps")))
    

    As mentioned in the TensorFlow 0.8 FAQ:

    How do I build a graph that works with variable batch sizes?

    It is often useful to build a graph that works with variable batch sizes, for example so that the same code can be used for (mini-)batch training, and single-instance inference. The resulting graph can be saved as a protocol buffer and imported into another program.

    When building a variable-size graph, the most important thing to remember is not to encode the batch size as a Python constant, but instead to use a symbolic Tensor to represent it. The following tips may be useful:

    • Use batch_size = tf.shape(input)[0] to extract the batch dimension from a Tensor called input, and store it in a Tensor called batch_size.

    • Use tf.reduce_mean() instead of tf.reduce_sum(...) / batch_size.

    • If you use placeholders for feeding input, you can specify a variable batch dimension by creating the placeholder with tf.placeholder(..., shape=[None, ...]). The None element of the shape corresponds to a variable-sized dimension.
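
    Putting those tips together, a small self-contained sketch (the placeholder name inputs and the feature dimension of 20 are illustrative, not taken from the FAQ) looks like this:

    import tensorflow as tf

    # A None element in the placeholder shape gives a variable batch dimension.
    inputs = tf.placeholder(tf.float32, shape=[None, 20], name="inputs")

    # Extract the batch size symbolically instead of hardcoding it.
    batch_size = tf.shape(inputs)[0]

    # Shapes that depend on the batch size are built from the symbolic value
    # (tf.pack in TensorFlow 0.8; renamed tf.stack in later releases).
    eps = tf.random_normal(tf.pack([batch_size, 20]), 0, 1, name="eps")

    # Average over the batch with tf.reduce_mean() rather than
    # tf.reduce_sum(...) / batch_size.
    loss = tf.reduce_mean(tf.square(inputs + eps))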