
tf.image.stateless_random_crop vs. tf.image.random_crop: shouldn't these be the same thing?


In TF 2.5 there are two functions for cropping an image: tf.image.stateless_random_crop and tf.image.random_crop. The documentation states that stateless_random_crop is deterministic (it always returns the same crop for a given seed). However, random_crop also has a seed parameter, so one would think it is deterministic too. What is the actual difference between these two functions? I cannot find any information about statelessness in TensorFlow.

The only difference between tf.image.stateless_random_crop and tf.image.random_crop is a single line: stateless_random_crop calls stateless_random_uniform where random_crop calls random_uniform.

stateless_random_crop: https://github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/python/ops/random_ops.py#L415-L465
random_crop: https://github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/python/ops/random_ops.py#L360-L412

I always thought that random_crop would return the same crop for a given seed, but it seems that may not always be true. Any enlightenment about statelessness in TensorFlow is greatly appreciated!


Solution

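  • random_crop returns the same sequence of crops only when both the global seed and the operation seed are set:

    1. the global seed is set with tf.random.set_seed(global_seed)
    2. the operation seed is set by passing the seed argument to the operation, i.e., tf.image.random_crop(value, size, seed=ops_seed)

    random_crop is stateful: repeated calls with the same seed advance an internal counter, so you get a reproducible sequence of (usually different) crops rather than the same crop every time. By contrast, the output of stateless_random_crop is fully determined by the seed you pass in (a shape-[2] int32 or int64 tensor), as long as the device and TensorFlow version stay the same.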
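A minimal sketch of the difference, assuming TF 2.x in eager mode; the 8x8 toy image, the seed values and the helper name `stateful_crops` are illustrative only. Re-setting the global seed replays the same sequence from random_crop, while stateless_random_crop returns the identical crop for an identical seed on every call.

```python
import numpy as np
import tensorflow as tf

image = tf.reshape(tf.range(64), (8, 8))  # toy 8x8 "image"
size = (3, 3)


def stateful_crops(global_seed, op_seed, n=3):
    """Hypothetical helper: n successive random_crop calls after setting both seeds."""
    tf.random.set_seed(global_seed)  # per the TF docs, this also resets op-level counters
    return [tf.image.random_crop(image, size, seed=op_seed).numpy() for _ in range(n)]


# Stateful: same seeds give the same *sequence* (calls within one run may differ).
first_run = stateful_crops(1234, 7)
second_run = stateful_crops(1234, 7)

# Stateless: the output depends only on (image, size, seed), call after call.
crop_a = tf.image.stateless_random_crop(image, size, seed=(1, 2))
crop_b = tf.image.stateless_random_crop(image, size, seed=(1, 2))

for a, b in zip(first_run, second_run):
    print(a[0, 0], b[0, 0])  # top-left value of each crop in both runs
```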

    You are right that the two functions look redundant, but tf.image.random_crop belongs to the old RNG API, which the TensorFlow guide strongly discourages and which may be buggy in graph mode. The new RNG APIs are tf.random.Generator and the stateless RNGs. For more information, see https://www.tensorflow.org/guide/random_numbers
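A rough sketch of how the two halves of the new API fit together (the seed 42 is arbitrary): two Generators created from the same seed produce the same draws, each draw advances a generator's state, and a stateless op such as tf.random.stateless_uniform is a pure function of the seed it receives.

```python
import numpy as np
import tensorflow as tf

# Two generators created from the same seed start in the same state.
g1 = tf.random.Generator.from_seed(42)
g2 = tf.random.Generator.from_seed(42)

first_g1 = g1.uniform_full_int([2], dtype=tf.int32)
first_g2 = g2.uniform_full_int([2], dtype=tf.int32)

# Each draw advances g1's internal state, so its next draw differs.
second_g1 = g1.uniform_full_int([2], dtype=tf.int32)

# A stateless op is a pure function of (shape, seed): same seed, same output.
x = tf.random.stateless_uniform([3], seed=first_g1)
y = tf.random.stateless_uniform([3], seed=first_g1)

print(first_g1.numpy(), first_g2.numpy(), second_g1.numpy())
print(x.numpy(), y.numpy())
```

This split is what the class in the answer relies on: the Generator carries the state, and the stateless crop only consumes seeds.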

    Using tf.random.Generator in combination with stateless_random_crop:

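    import numpy as np
    import tensorflow as tf

    class NewRNGRandomCrop:
      """Random crop that draws a fresh seed from a tf.random.Generator per call."""
      def __init__(self, seed, size):
        # The Generator owns the RNG state; from_seed makes the run reproducible.
        self.rand_generator = tf.random.Generator.from_seed(seed)
        self.size = size

      def random_crop(self, x):
        # New shape-[2] seed each call; the crop itself is a pure function of it.
        seed = self.rand_generator.uniform_full_int([2], dtype=tf.int32)
        return tf.image.stateless_random_crop(x, self.size, seed=seed)

    dummy_dataset = tf.data.Dataset.from_tensor_slices(
        np.arange(2 * 3 * 3).reshape((2, 3, 3))).batch(1)
    cropper = NewRNGRandomCrop(88883, (1, 2, 2))
    dummy_dataset = dummy_dataset.map(cropper.random_crop)

    for image in dummy_dataset:
      print(image)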
    

    Example outputs:

    tf.Tensor(
    [[[3 4]
      [6 7]]], shape=(1, 2, 2), dtype=int64)
    tf.Tensor(
    [[[ 9 10]
      [12 13]]], shape=(1, 2, 2), dtype=int64)
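If you want per-element seeds without writing a class, Generator.make_seeds returns a [2, count] int64 tensor of seeds intended for stateless ops. The sketch below is an alternative under that assumption (the helper names `make_cropper` and `crops` are hypothetical). It uses make_seeds inside Dataset.map and checks that rebuilding the generator with the same seed replays exactly the same crops. It maps without num_parallel_calls, since parallel mapping could change which element receives which seed.

```python
import numpy as np
import tensorflow as tf

IMAGES = np.arange(2 * 3 * 3).reshape((2, 3, 3))  # same toy data as above


def make_cropper(seed, size):
    """Hypothetical helper: a map function whose crops depend only on `seed`."""
    rng = tf.random.Generator.from_seed(seed)

    def crop(x):
        # make_seeds(1) has shape [2, 1]; column 0 is one shape-[2] int64 seed.
        crop_seed = rng.make_seeds(1)[:, 0]
        return tf.image.stateless_random_crop(x, size, seed=crop_seed)

    return crop


def crops(seed, size=(1, 2, 2)):
    """Hypothetical helper: run the pipeline once and collect the crops."""
    crop_fn = make_cropper(seed, size)  # keep a reference so the generator stays alive
    ds = tf.data.Dataset.from_tensor_slices(IMAGES).batch(1).map(crop_fn)
    return [t.numpy() for t in ds]


run_a = crops(88883)
run_b = crops(88883)  # fresh generator, same seed

for a, b in zip(run_a, run_b):
    print(a.tolist(), b.tolist())
```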