
Corrupting Tensor with Gaussian Noise based on a Variance / Stddev Matrix


I have a CNN output mask (logits, shape = [batch, x, y, classes]) and a matrix (sigma) that assigns a stddev for Gaussian noise to each logit (i.e. sigma.shape = [x, y, classes]). I want to corrupt each logit with Gaussian noise based on its corresponding sigma. In TensorFlow, I only found a Gaussian that works with scalar parameters: tf.random_normal.

Hence I used loops, "computed" the noise for each logit separately (mean = logit[b, x, y, c], stddev = sigma[x, y, c]) and used tf.stack to get my 4-D tensor back. BUT: for a [1, 1024, 1024, 2] tensor this already takes ages (to the extent of "didn't finish"), which makes sense, I guess, since it has to create and stack more than a million tensor objects. Anyway, I am pretty sure this is not the way to go...

But how should I do this? Is there a tf.random_normal that allows working in higher dimensions?

I know that tf.random_normal can return a higher-dimensional tensor of (arbitrary) shape, but that doesn't work for me since it applies the same stddev to every element (the mean doesn't matter since I could generate zero-mean noise and tf.add() it to the logits).

If it's relevant in any way: a compromise I could live with for now (to speed things up) would be to generate noise with a stddev based on pixels only (not on class), i.e. *sigma.shape = [x, y]*. But that only removes one loop, not the main problem (the x*y iterations).

Here is the code for the "loop" approach. I tested it for small shapes ([1, 8, 8, 2]). I know I could leave out the for b loop, but it's not the main cost; the x*y computation is the real problem:

logits = self.output_mask   # e.g. shape: [1, 1024, 1024, 2]
sigma = tf.ones(shape=[1024, 1024, 2], dtype=logits.dtype)  # dummy values, will be a learned parameter later

corrupted_logits_b = []
for b in range(logits.shape[0]):
    corrupted_logits_x = []
    for x in range(logits.shape[1]):
        corrupted_logits_y = []
        for y in range(logits.shape[2]):
            corrupted_logits_c = []
            for c in range(logits.shape[3]):
                # this is where the noise is computed
                # (added to the logit since mean = logit)
                corrupted_logit = tf.random_normal(tf.shape(logits[b, x, y, c]),
                                                   mean=logits[b, x, y, c],
                                                   stddev=sigma[x, y, c])
                # "values"/logit-tensors are appended to lists and
                # ... and stacked to form higher-dim tensors
                corrupted_logits_c.append(corrupted_logit)
            corrupted_logits_y.append(tf.stack(corrupted_logits_c, axis=-1))
        corrupted_logits_x.append(tf.stack(corrupted_logits_y))
    corrupted_logits_b.append(tf.stack(corrupted_logits_x))
corrupted_logits = tf.stack(corrupted_logits_b)  # axis 0, so batch stays the first dimension

Solution

  • So I actually found a good workaround, using the fact that a standard normal distribution N(0, 1) multiplied by σ (sigma) yields a normal distribution with stddev σ, i.e. N(0, σ²):

    N(0, σ²) = N(0, 1) ⋅ σ
    
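    This identity can be sanity-checked numerically; a minimal sketch using NumPy (just for illustration, not part of the TensorFlow setup):

    ```python
    import numpy as np

    rng = np.random.default_rng(0)
    sigma = 2.5

    # draw standard-normal N(0, 1) samples and scale them by sigma
    scaled = rng.standard_normal(1_000_000) * sigma

    # the sample stddev should now be close to sigma
    print(scaled.std())
    ```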

    Hence you can do:

    # sigma must have the shape of logits or be broadcastable to it
    sigma = tf.placeholder(shape=logits.shape, dtype=logits.dtype)
    def corrupt_logits(logits, sigma):
        # generate a normal distribution N(0,1) of desired shape (4D in my case)
        gaussian = tf.random_normal(tf.shape(logits), mean=0.0, stddev=1.0, dtype=logits.dtype)
        # turn into normal distribution N(0,σ**2) by multiplying with σ
        # since sigma is a 4D tensor of tf.shape(logits), it can "assign" an individual σ to each logit
        noise = tf.multiply(gaussian, sigma)
        # add zero-mean/stddev-σ noise to logits
        return tf.add(logits, noise)
    

    This is a lot faster than doing it with for-loops, and it supports several variants: if sigma is constant along some dimension (e.g. it only varies per x, y), it is automatically broadcast over the remaining dimensions (e.g. batch_size, classes).
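    As a concrete illustration of that broadcasting behaviour, here is a small NumPy sketch (NumPy instead of TensorFlow just to keep it self-contained; the shapes are made up):

    ```python
    import numpy as np

    rng = np.random.default_rng(1)
    batch, x, y, classes = 4, 8, 8, 2
    logits = rng.standard_normal((batch, x, y, classes))

    # per-pixel stddev, shared across batch and classes:
    # shape [x, y, 1] broadcasts against [batch, x, y, classes]
    sigma = rng.uniform(0.5, 2.0, size=(x, y, 1))

    # N(0, 1) noise of the full logit shape, scaled element-wise by sigma
    gaussian = rng.standard_normal(logits.shape)
    corrupted = logits + gaussian * sigma

    print(corrupted.shape)  # (4, 8, 8, 2)
    ```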