
How to calculate correlation of colours in a dataset?


In this Distill article (https://distill.pub/2017/feature-visualization/), the authors write in footnote 8:

The Fourier transforms decorrelates spatially, but a correlation will still exist between colors. To address this, we explicitly measure the correlation between colors in the training set and use a Cholesky decomposition to decorrelate them.
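A minimal NumPy sketch of what the quoted footnote seems to describe (my reading, not the authors' code), assuming all pixels have already been stacked into an array of shape [3, N]: compute the 3x3 color covariance C, take its Cholesky factor L (lower-triangular, with L @ L.T == C), and solve L z = x - mean. The resulting z has identity covariance, i.e. decorrelated, unit-variance color channels. The function name `cholesky_whiten` is hypothetical.

```python
import numpy as np

def cholesky_whiten(pixels):
    """Decorrelate RGB pixels with the Cholesky factor of their covariance.

    pixels: array of shape [3, N] (channels x pixels).
    Returns (z, l) where l @ l.T equals the pixel covariance and
    z = l^{-1} (pixels - mean) has identity covariance.
    """
    mean = pixels.mean(axis=1, keepdims=True)
    centered = pixels - mean
    cov = centered @ centered.T / centered.shape[1]  # 3x3 population covariance
    l = np.linalg.cholesky(cov)                      # lower-triangular, l @ l.T == cov
    z = np.linalg.solve(l, centered)                 # solves l @ z == centered
    return z, l
```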

I have trouble understanding how to do that. I understand that for a single image I can calculate a correlation matrix by reshaping it from [channels, height, width] to [channels, height*width]. But how do I take the whole dataset into account? I could average the per-image matrices, but that has nothing to do with a Cholesky decomposition.
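Instead of averaging per-image matrices, one option is to pool every pixel of every image into a single set of 3x3 statistics. This gives the same covariance as stacking all pixels of the dataset into one [3, N] array, without holding that array in memory. A minimal sketch, assuming each image is a float array shaped [channels, height, width] with 3 channels; the function name `dataset_color_stats` is hypothetical:

```python
import numpy as np

def dataset_color_stats(images):
    """Pool all pixels of all images into one RGB mean, covariance and correlation.

    images: iterable of float arrays shaped [channels, height, width].
    """
    n = 0
    s = np.zeros(3)        # running per-channel sum
    ss = np.zeros((3, 3))  # running sum of per-pixel outer products
    for img in images:
        x = img.reshape(img.shape[0], -1)  # [channels, height*width]
        n += x.shape[1]
        s += x.sum(axis=1)
        ss += x @ x.T
    mean = s / n
    cov = ss / n - np.outer(mean, mean)    # population covariance over all pixels
    std = np.sqrt(np.diag(cov))
    corr = cov / np.outer(std, std)        # normalize covariance to correlation
    return mean, cov, corr
```

Because only running sums are kept, this works on a dataset streamed image by image; `cov` is the matrix you would then decompose. For very large datasets, subtracting a rough per-channel mean from each image first reduces floating-point cancellation in `ss / n - outer(mean, mean)`.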

Inspecting the code (https://github.com/tensorflow/lucid/blob/master/lucid/optvis/param/color.py#L24) confuses me even more. There's no code for calculating correlations, only a hard-coded matrix, and the decorrelation is done by matrix multiplication with it. The matrix is named color_correlation_svd_sqrt, which refers to an SVD, yet SVD isn't mentioned anywhere else. Also, the matrix is not triangular, so it can't be a Cholesky factor.
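The "matrix multiplication with this matrix" can be sketched as a per-pixel product: every pixel's RGB vector is multiplied by a 3x3 matrix w. If the input channels are uncorrelated with unit variance, the output pixels have covariance w @ w.T. This is a minimal sketch assuming an image stored as [height, width, 3]; the function name `apply_color_matrix` and the values of `w` are hypothetical, not lucid's. I believe lucid applies its matrix in this direction (decorrelated parameters in, RGB colors out), but check the linked color.py for the details, such as any normalization.

```python
import numpy as np

def apply_color_matrix(image_hwc, w):
    """Multiply each pixel's RGB vector by a 3x3 matrix: out[y, x] = w @ image[y, x]."""
    height, width, channels = image_hwc.shape
    flat = image_hwc.reshape(-1, channels)  # [height*width, channels]
    return (flat @ w.T).reshape(height, width, channels)

# Usage with a made-up lower-triangular w (hypothetical values, not lucid's matrix).
w = np.array([[1.0, 0.0, 0.0],
              [0.8, 0.6, 0.0],
              [0.7, 0.3, 0.5]])
noise = np.random.default_rng(0).standard_normal((64, 64, 3))  # uncorrelated channels
colored = apply_color_matrix(noise, w)  # pixel covariance is approximately w @ w.T
```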

Clarifications on any points I've mentioned would be greatly appreciated.


Solution

  • I worked out the answer to your question in a related question: "How to calculate the 3x3 covariance matrix for RGB values across an image dataset?"

    In short, calculate the 3x3 RGB covariance matrix over the whole image dataset, then build a square root of it from its SVD:

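    import torch
    # dataset_rgb_cov_matrix: 3x3 RGB covariance over all pixels of the dataset
    U, S, V = torch.svd(dataset_rgb_cov_matrix)  # newer PyTorch: torch.linalg.svd (returns Vh)
    epsilon = 1e-10  # small constant added to the singular values before the sqrt
    # svd_sqrt @ svd_sqrt.T ~= dataset_rgb_cov_matrix
    svd_sqrt = U @ torch.diag(torch.sqrt(S + epsilon))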
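Why a non-triangular matrix can stand in for a Cholesky factor: a covariance matrix C is symmetric positive semi-definite, so its SVD has the form C = U S U^T, and W = U diag(sqrt(S)) therefore satisfies W W^T = C. That is the same defining property as the Cholesky factor L (L L^T = C). When C is invertible, any two such square roots differ only by an orthogonal matrix Q = L^{-1} W, so either can map decorrelated values to colors with covariance C, which would explain the "svd_sqrt" name. The sketch below checks this in NumPy rather than PyTorch (`np.linalg.svd` returns Vh instead of V, which is unused here); the function name `color_sqrt_factors` and the synthetic data are hypothetical.

```python
import numpy as np

def color_sqrt_factors(cov, eps=1e-10):
    """Return two matrices w with w @ w.T ~= cov: an SVD-based root and the Cholesky factor."""
    # SVD root, as in the torch snippet: U @ diag(sqrt(S + eps)).
    u, s, _ = np.linalg.svd(cov)
    w_svd = u @ np.diag(np.sqrt(s + eps))
    # Cholesky root: lower-triangular; requires cov to be positive definite.
    w_chol = np.linalg.cholesky(cov)
    return w_svd, w_chol

# Usage with synthetic, strongly correlated "RGB" pixels (made-up data).
rng = np.random.default_rng(0)
base = rng.random(10_000)
pixels = np.stack([base + 0.2 * rng.random(10_000) for _ in range(3)])  # [3, N]
cov = np.cov(pixels)
w_svd, w_chol = color_sqrt_factors(cov)
q = np.linalg.solve(w_chol, w_svd)  # L^{-1} W: orthogonal up to the eps term
```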