Tags: python, mean, updates, variance, online-algorithm

Efficient algorithm for online Variance over image batches


I have a large number of images and want to calculate the per-channel variance across all of them. I am having trouble finding an efficient (or even correct) algorithm for this.

I found Welford's online algorithm, but it is far too slow because it processes one value at a time and does not vectorize over a single image or a batch of images.
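For reference, here is a minimal scalar sketch of the textbook Welford update. The helper names `welford_update` and `welford_variance` are hypothetical. It shows where the slowness comes from: every single pixel value passes through one iteration of a Python-level loop.

```python
def welford_update(count, mean, m2, x):
    """One step of Welford's online algorithm for a single scalar value x.

    count: number of values seen so far
    mean:  running mean
    m2:    running sum of squared deviations from the running mean
    """
    count += 1
    delta = x - mean
    mean += delta / count
    m2 += delta * (x - mean)  # uses the already-updated mean
    return count, mean, m2


def welford_variance(values):
    """Population variance (divide by n) of an iterable, one value at a time."""
    count, mean, m2 = 0, 0.0, 0.0
    for x in values:  # Python-level loop over every value: the slow part for images
        count, mean, m2 = welford_update(count, mean, m2, x)
    return m2 / count if count else float("nan")


print(welford_variance([1.0, 2.0, 3.0, 4.0]))  # 1.25
```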

How can I speed it up using vectorization or built-in variance functions?


Solution

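  • These are the two functions needed to combine the means and variances of two batches. Both work element-wise on vectors (e.g. the 3 color channels), and each batch's mean and variance can be obtained from built-in methods such as batch.mean() and batch.var(), reduced over every axis except the channel axis. Note that the variance formula assumes the population variance (dividing by n), which is what NumPy's var() returns by default.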

    Equations taken from: https://notmatthancock.github.io/2017/03/23/simple-batch-stat-updates.html

    # m   number of samples (or pixels) over all previous batches
    # n   number of samples in the new incoming batch
    # mu1 previous mean
    # mu2 mean of the current batch
    # v1  previous variance
    # v2  variance of the current batch
    
    def combine_means(mu1, mu2, m, n):
        """
        Updates old mean mu1 from m samples with mean mu2 of n samples.
        Returns the mean of the m+n samples.
        """
        return (m/(m+n))*mu1 + (n/(m+n))*mu2
    
    def combine_vars(v1, v2, mu1, mu2, m, n):
        """
        Updates old variance v1 from m samples with variance v2 of n samples.
        Returns the variance of the m+n samples.
        """
        return (m/(m+n))*v1 + n/(m+n)*v2 + m*n/(m+n)**2 * (mu1 - mu2)**2
        
    

    As you can see, the functions could be simplified slightly by reusing common subexpressions such as m+n, but I kept them in this pure form for readability.
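A minimal end-to-end sketch of using these functions over a stream of batches. It assumes NumPy and channels-last batches of shape (N, H, W, C). `channel_stats` is a hypothetical helper name, and the two combine functions are repeated so the snippet runs on its own. The result is checked against `var()` over all the data concatenated at once.

```python
import numpy as np


def combine_means(mu1, mu2, m, n):
    """Mean of m + n samples, given mean mu1 of m samples and mu2 of n samples."""
    return (m / (m + n)) * mu1 + (n / (m + n)) * mu2


def combine_vars(v1, v2, mu1, mu2, m, n):
    """Population variance of m + n samples, given each group's population variance."""
    return (m / (m + n)) * v1 + (n / (m + n)) * v2 + m * n / (m + n) ** 2 * (mu1 - mu2) ** 2


def channel_stats(batches):
    """Per-channel mean and population variance over an iterable of image batches.

    Assumption: each batch has shape (N, H, W, C), i.e. channels last.
    """
    count, mean, var = 0, 0.0, 0.0
    for batch in batches:
        batch = np.asarray(batch, dtype=np.float64)
        n = batch.shape[0] * batch.shape[1] * batch.shape[2]  # pixels per channel
        if n == 0:
            continue  # skip empty batches to avoid dividing by zero
        mu2 = batch.mean(axis=(0, 1, 2))  # shape (C,)
        v2 = batch.var(axis=(0, 1, 2))    # ddof=0: population variance
        # Update the variance first, because it needs the *old* running mean.
        # With count == 0 this reduces to var = v2 and mean = mu2.
        var = combine_vars(var, v2, mean, mu2, count, n)
        mean = combine_means(mean, mu2, count, n)
        count += n
    return mean, var


rng = np.random.default_rng(0)
batches = [rng.random((b, 8, 8, 3)) for b in (4, 7, 2)]
mean, var = channel_stats(batches)
all_pixels = np.concatenate(batches).reshape(-1, 3)  # one row per pixel
print(np.allclose(mean, all_pixels.mean(axis=0)), np.allclose(var, all_pixels.var(axis=0)))  # True True
```

Each batch is reduced with a single vectorized call, so the only Python-level loop is over batches, not pixels. For channels-first data of shape (N, C, H, W), reduce over axes (0, 2, 3) instead and count N*H*W pixels per batch.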