TensorFlow layers: using custom(ized) initialization function?


Why does obtaining a new initialization function with partial give me an error, while a lambda doesn't?

All of these functions:

f_init = partial(tf.random_normal, mean=0.0, stddev=0.01, partition_info=None)
f_init = partial(tf.contrib.layers.xavier_initializer, partition_info=None)
f_init = partial(tf.random_normal, mean=0.0, stddev=0.01)
f_init = tf.contrib.layers.xavier_initializer

Throw the following exception:

TypeError: ... got an unexpected keyword argument 'partition_info'

(where ... stands for xavier_initializer or whichever other function is wrapped)

This happens as soon as any of them is passed to a simple conv2d layer:

conv1 = tf.layers.conv2d(x, 32, [5, 5],
                         strides=[1, 1],
                         padding="same",
                         activation=tf.nn.relu,
                         kernel_initializer=f_init,
                         name="conv1")

However, if I use a lambda to obtain custom initialization functions:

f_init = lambda shape, dtype, partition_info=None:\
         tf.random_normal(shape, mean=0.0, stddev=0.01, dtype=dtype)

...it works without any problems.

Shouldn't partial also return a new anonymous function wrapping, e.g., tf.random_normal with mean=0.0 and stddev=0.01 already supplied, just like the lambda expression does?


Solution

  • The error says that the functions tf.random_normal and tf.contrib.layers.xavier_initializer do not have a parameter named partition_info, and that is indeed the case: neither function defines such a parameter (see the API documentation of both).

    Your lambda works because it accepts partition_info but does not forward it to tf.random_normal, which is correct.
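
    The failure can be reproduced without TensorFlow. This is a simplified stand-in, not TensorFlow's code: fake_random_normal is a hypothetical function with the same parameter list as tf.random_normal, and call_like_tf is a hypothetical helper that mimics how the layer invokes kernel_initializer, namely with shape, dtype and partition_info passed as arguments.

```python
import random
from functools import partial


def fake_random_normal(shape, mean=0.0, stddev=1.0, dtype="float32",
                       seed=None, name=None):
    """Hypothetical stand-in with the same parameters as tf.random_normal.

    Returns a flat list of normally distributed floats, one per element.
    """
    rng = random.Random(seed)
    count = 1
    for dim in shape:
        count *= dim
    return [rng.gauss(mean, stddev) for _ in range(count)]


def call_like_tf(initializer, shape, dtype="float32"):
    """Hypothetical helper: calls the initializer the way the layer does (simplified)."""
    return initializer(shape, dtype=dtype, partition_info=None)


# The lambda accepts partition_info and simply drops it.
lambda_init = lambda shape, dtype, partition_info=None: fake_random_normal(
    shape, mean=0.0, stddev=0.01, dtype=dtype)

# partial only pre-fills mean and stddev; it adds no partition_info parameter.
partial_init = partial(fake_random_normal, mean=0.0, stddev=0.01)

values = call_like_tf(lambda_init, [5, 5])
print(len(values))  # 25

error_message = None
try:
    call_like_tf(partial_init, [5, 5])
except TypeError as err:
    error_message = str(err)
print(error_message)
# fake_random_normal() got an unexpected keyword argument 'partition_info'
```

    The lambda behaves like an adapter: its signature is what the caller expects, and its body decides which arguments actually reach the wrapped function.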

    Also take care not to confuse the functions that return initialization values (like tf.random_normal) with the corresponding initializers (like tf.random_normal_initializer). The former returns a tensor of random values; the latter creates a callable that expects a shape, a dtype and partition_info. When called, this callable returns the normally distributed values.
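
    To make the distinction concrete, here is a pure-Python sketch (no TensorFlow required) of a value function versus an initializer factory. normal_values and normal_initializer are hypothetical names chosen for illustration; normal_initializer only loosely mirrors what tf.random_normal_initializer does: it captures mean and stddev and returns a callable with the (shape, dtype, partition_info) signature.

```python
import random


def normal_values(shape, mean=0.0, stddev=1.0, seed=None):
    """Value function (like tf.random_normal): returns the values directly."""
    rng = random.Random(seed)
    count = 1
    for dim in shape:
        count *= dim
    return [rng.gauss(mean, stddev) for _ in range(count)]


def normal_initializer(mean=0.0, stddev=1.0, seed=None):
    """Initializer factory (loosely like tf.random_normal_initializer).

    Returns a callable that TensorFlow-style code can invoke later with
    shape, dtype and partition_info.
    """
    def initializer(shape, dtype=None, partition_info=None):
        # dtype and partition_info are accepted but unused in this sketch.
        return normal_values(shape, mean=mean, stddev=stddev, seed=seed)
    return initializer


init = normal_initializer(mean=0.0, stddev=0.01, seed=42)  # nothing sampled yet
kernel = init([5, 5], dtype="float32", partition_info=None)
print(len(kernel))  # 25
```

    The same distinction explains the xavier_initializer cases in the question: in TensorFlow 1.x, tf.contrib.layers.xavier_initializer is itself a factory, so it has to be called to obtain the initializer, e.g. kernel_initializer=tf.contrib.layers.xavier_initializer().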

    Your lambda conforms to this signature, and thus it works. With partial, however, the resulting callable accepts only the parameters of the wrapped function; the keyword arguments you freeze merely receive new defaults, and no new parameters are added:

    f_init = partial(tf.random_normal, mean=0.0, stddev=0.01)
    

    Since tf.random_normal has the signature:

    def random_normal(shape, mean=0.0, stddev=1.0, dtype=dtypes.float32,
                      seed=None, name=None):
        # ...
    

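    you can use the partial object roughly as if it were defined like this (the frozen arguments become keyword-only parameters with new defaults):

    def f_init(shape, *, mean=0.0, stddev=0.01, dtype=dtypes.float32,
               seed=None, name=None):
        # ...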
    

    Note that there is no parameter named partition_info, but TensorFlow will try to pass it when calling f_init, resulting in the error you got.
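
    You can confirm this with inspect.signature from the standard library. The sketch again uses a hypothetical fake_random_normal with tf.random_normal's parameter list, so it runs without TensorFlow.

```python
import inspect
from functools import partial


def fake_random_normal(shape, mean=0.0, stddev=1.0, dtype="float32",
                       seed=None, name=None):
    """Hypothetical stand-in with the same parameters as tf.random_normal."""
    raise NotImplementedError("only the signature matters here")


partial_init = partial(fake_random_normal, mean=0.0, stddev=0.01)
lambda_init = lambda shape, dtype, partition_info=None: None

partial_sig = inspect.signature(partial_init)
lambda_sig = inspect.signature(lambda_init)

print(partial_sig)
# (shape, *, mean=0.0, stddev=0.01, dtype='float32', seed=None, name=None)
print(lambda_sig)
# (shape, dtype, partition_info=None)
```

    The frozen mean and stddev still show up (as keyword-only parameters with new defaults), but partition_info appears only in the lambda's signature.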

    To customize things like mean and stddev, you do not need to create a custom initializer, though. This, for example, creates an initializer that returns normally distributed values with mean 0.0 and standard deviation 0.01:

    f_init = tf.random_normal_initializer(mean=0.0, stddev=0.01)
    

    But if you need a custom initializer, e.g. to implement custom initialization logic, you could follow this pattern, modeled on the RandomNormal class in TensorFlow's init_ops module:

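    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import random_ops
    from tensorflow.python.ops.init_ops import Initializer

    class RandomNormal(Initializer):
        def __init__(self, mean=0.0, stddev=1.0, seed=None, dtype=dtypes.float32):
            self.mean = mean
            self.stddev = stddev
            self.seed = seed
            self.dtype = dtypes.as_dtype(dtype)
            # Only floating-point dtypes make sense for normally distributed values
            if not self.dtype.is_floating:
                raise ValueError("Expected floating point type, got %s." % self.dtype)

        def __call__(self, shape, dtype=None, partition_info=None):
            # TensorFlow passes shape, dtype and partition_info, so
            # partition_info must be accepted even if it is not used
            if dtype is None:
                dtype = self.dtype
            normal = random_ops.random_normal(shape, self.mean, self.stddev,
                                              dtype, seed=self.seed)
            # do what you want with normal here
            return normal

        def get_config(self):
            return {"mean": self.mean,
                    "stddev": self.stddev,
                    "seed": self.seed,
                    "dtype": self.dtype.name}

    # Alias to a lower_case, 'function-style' name (as TensorFlow does for its own initializers)
    random_normal_initializer = RandomNormal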
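
    The class-based pattern can also be exercised without TensorFlow. In this sketch, NormalInitializer is a hypothetical pure-Python analog of the RandomNormal class above (no Initializer base class, Python's random module instead of random_ops). It shows the two properties that matter: instances accept partition_info, and get_config returns enough information to rebuild an equivalent initializer.

```python
import random


class NormalInitializer:
    """Hypothetical pure-Python analog of the RandomNormal pattern."""

    def __init__(self, mean=0.0, stddev=1.0, seed=None, dtype="float32"):
        self.mean = mean
        self.stddev = stddev
        self.seed = seed
        self.dtype = dtype

    def __call__(self, shape, dtype=None, partition_info=None):
        if dtype is None:
            dtype = self.dtype  # resolved for parity with the pattern; unused for Python floats
        # partition_info is accepted (so callers may pass it) but not used here.
        rng = random.Random(self.seed)
        count = 1
        for dim in shape:
            count *= dim
        return [rng.gauss(self.mean, self.stddev) for _ in range(count)]

    def get_config(self):
        return {"mean": self.mean, "stddev": self.stddev,
                "seed": self.seed, "dtype": self.dtype}


init = NormalInitializer(mean=0.0, stddev=0.01, seed=1)
kernel = init([3, 3], partition_info=None)
clone = NormalInitializer(**init.get_config())
print(kernel == clone([3, 3]))  # True
```

    Because both instances use the same fixed seed, the rebuilt initializer produces identical values; with seed=None they would differ on every call.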