python-3.x, tensorflow, deep-learning, tensorflow2.0

Apply tf.ensure_shape for multiple outputs


I have this code:

import tensorflow as tf
import numpy as np

def scale(X, a=-1, b=1, dtype='float32'):
    if a > b:
        a, b = b, a
    X = tf.cast(X, dtype=dtype)  # np.ones yields float64; cast so X - xmin does not mix dtypes
    xmin = tf.cast(tf.math.reduce_min(X), dtype=dtype)
    xmax = tf.cast(tf.math.reduce_max(X), dtype=dtype)
    X = (X - xmin) / (xmax - xmin)
    scaled = X * (b - a) + a
    return scaled, xmin, xmax


a = np.ones((10, 20, 20, 2))

dataset = tf.data.Dataset.from_tensor_slices(a)


data = dataset.map(lambda x: tf.py_function(scale,
                                            [x], 
                                            (tf.float32, tf.float32, tf.float32)))

Up to this point it works, and I get:

data

<MapDataset shapes: (<unknown>, <unknown>, <unknown>), types: (tf.float32, tf.float32, tf.float32)>

Now I need to use tf.ensure_shape to set the shapes.

If, for example, the scale function returned only one value (scaled), I would do:

data = data.map(lambda x: tf.ensure_shape(x, [20, 20, 2]))  # each element is one (20, 20, 2) slice
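Note that from_tensor_slices slices along the first axis, so even though the array passed in has shape (10, 20, 20, 2), each dataset element has shape (20, 20, 2), and that per-element shape is what tf.ensure_shape must be given. A minimal sketch of the single-output case; scale_only is a hypothetical stand-in that just doubles its input, used only to show how the shape is lost by tf.py_function and restored by tf.ensure_shape:

```python
import numpy as np
import tensorflow as tf

a = np.ones((10, 20, 20, 2))
dataset = tf.data.Dataset.from_tensor_slices(a)

# from_tensor_slices splits along axis 0: 10 elements of shape (20, 20, 2).
print(dataset.element_spec.shape)


def scale_only(x):
    # Hypothetical single-output function (placeholder computation).
    # tf.py_function runs arbitrary Python, so the static shape becomes unknown.
    return tf.py_function(lambda t: t * 2.0, [x], tf.float64)


data = dataset.map(scale_only)
# ensure_shape checks the per-element shape at run time and records it statically.
data = data.map(lambda x: tf.ensure_shape(x, [20, 20, 2]))
print(data.element_spec.shape)
```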

But how do I do this now that there are three output values?

I need all three results of the scale function (the scaled values, xmin and xmax), which is why I am doing all this. If there is another way, I am not aware of it.
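One other way, not taken from the original thread: the shapes show up as <unknown> because tf.py_function hides its outputs from TensorFlow's shape inference. Since scale uses only TensorFlow ops, it can be passed to Dataset.map directly; map then traces it as a graph function and infers the static shapes of all three outputs. A minimal sketch, assuming the input is cast to float32 inside scale and that random data (instead of np.ones, which makes xmax equal xmin) is acceptable for the demo:

```python
import numpy as np
import tensorflow as tf


def scale(X, a=-1, b=1, dtype='float32'):
    # Same min-max scaling as in the question; X is cast first so that
    # float64 input is not mixed with float32 xmin/xmax.
    if a > b:
        a, b = b, a
    X = tf.cast(X, dtype=dtype)
    xmin = tf.math.reduce_min(X)
    xmax = tf.math.reduce_max(X)
    X = (X - xmin) / (xmax - xmin)
    scaled = X * (b - a) + a
    return scaled, xmin, xmax


# Random values (assumption for the demo) so that xmax != xmin.
a = np.random.uniform(1, 5, size=(10, 20, 20, 2))
dataset = tf.data.Dataset.from_tensor_slices(a)

# No tf.py_function: map traces scale, so the shapes (20, 20, 2), () and ()
# are inferred without any tf.ensure_shape call.
data = dataset.map(scale)
print(data.element_spec)
```

This only works while scale stays expressible in TensorFlow ops; if it needs arbitrary Python (NumPy, SciPy, and so on), tf.py_function plus tf.ensure_shape is still required.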


Solution

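  • If it is just about turning the unknown shapes into known ones, I think you can use tf.reshape. Because each dataset element is a tuple of three tensors, Dataset.map passes them to the mapped function as three separate arguments, so each one can be reshaped individually.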

    import tensorflow as tf
    
    def scale(X, a=-1, b=1, dtype='float32'):
        if a > b:
            a, b = b, a
        xmin = tf.cast(tf.math.reduce_min(X), dtype=dtype)
        xmax = tf.cast(tf.math.reduce_max(X), dtype=dtype)
        X = (X - xmin) / (xmax - xmin)
        scaled = X * (b - a) + a
        return scaled, xmin, xmax
    
    a = tf.random.uniform(shape=[10, 20, 20, 2], minval=1, maxval=5)
    dataset = tf.data.Dataset.from_tensor_slices(a)
    dataset = dataset.map(
        lambda x: tf.py_function(
            scale,
            [x], 
            (tf.float32, tf.float32, tf.float32))
    )
    
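    def set_shape(x, y, z):
        # map unpacks each (scaled, xmin, xmax) tuple into three arguments.
        x = tf.reshape(x, [-1, 20, 20, 2])  # adds a leading axis of size 1
        y = tf.reshape(y, [1])              # scalar -> shape [1]
        z = tf.reshape(z, [1])              # scalar -> shape [1]
        return x, y, z
    
    dataset = dataset.map(set_shape)
    a, b, c = next(iter(dataset))
    print((a.shape, b.shape, c.shape))
    # (TensorShape([1, 20, 20, 2]), TensorShape([1]), TensorShape([1]))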
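Since the question asks specifically about tf.ensure_shape, the same per-component idea works with it instead of tf.reshape. Unlike tf.reshape, tf.ensure_shape does not change the data; it only checks the shape at run time and records it statically, so the scaled values keep their (20, 20, 2) shape and xmin/xmax stay scalars. A minimal sketch, assuming those per-element shapes; the helper name ensure_shapes is hypothetical:

```python
import tensorflow as tf


def scale(X, a=-1, b=1, dtype='float32'):
    if a > b:
        a, b = b, a
    xmin = tf.cast(tf.math.reduce_min(X), dtype=dtype)
    xmax = tf.cast(tf.math.reduce_max(X), dtype=dtype)
    X = (X - xmin) / (xmax - xmin)
    scaled = X * (b - a) + a
    return scaled, xmin, xmax


a = tf.random.uniform(shape=[10, 20, 20, 2], minval=1, maxval=5)
dataset = tf.data.Dataset.from_tensor_slices(a)
dataset = dataset.map(
    lambda x: tf.py_function(scale, [x], (tf.float32, tf.float32, tf.float32))
)


def ensure_shapes(scaled, xmin, xmax):
    # Hypothetical helper: map passes the three tuple components as three
    # arguments, and each one is checked against its own expected shape.
    scaled = tf.ensure_shape(scaled, [20, 20, 2])  # one slice of the input
    xmin = tf.ensure_shape(xmin, [])               # scalar
    xmax = tf.ensure_shape(xmax, [])               # scalar
    return scaled, xmin, xmax


dataset = dataset.map(ensure_shapes)
print(dataset.element_spec)
```

If a component ever arrives with a different shape, tf.ensure_shape raises an error when the element is produced, which tf.reshape would not necessarily do.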