numpy, pytorch, python-xarray

How to calculate mean and standard deviation of ~70,000 GeoTIFF files that combined are too big to fit into memory?


I am trying to calculate the mean and standard deviation across ~70,000 3-band TIF files that, when combined, are too big to fit into memory. Ultimately, I want to standardize the dataset, i.e. remove the mean and scale to unit variance.

Calculating the mean is not an issue, as it can be computed per channel in batches, but doing the same for the standard deviation leads to large errors (see the discussion at https://discuss.pytorch.org/t/about-normalization-using-pre-trained-vgg16-networks/23560?u=kuzand). I have tried the following, which loads the data into tensors using PyTorch.

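from glob import glob

import numpy as np
import rioxarray
import torch
import xarray
from torch.utils.data import DataLoader, Dataset

class batchDataset(Dataset):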
    def __init__(self,
                 inputs,
                 transform=None
                 ):
        self.inputs = inputs
        self.transform = transform

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self,
                    index: int):
        input_id = self.inputs[index]
        image = np.array(rioxarray.open_rasterio(input_id))

        # Preprocessing
        if self.transform:
            image = self.transform(image)

        image = torch.from_numpy(image)
        return image

def meanAndStd(loader, channels=3):
    """
    Compute the mean and sd in an online fashion
    Var[x] = E[X^2] - E^2[X]

    loader: a custom version of torch.utils.data.DataLoader 
    channels: the number of channels in the image (RGB = 3 / greyscale = 1)
    """
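    cnt = 0
    # Start from zeros: torch.empty returns uninitialized memory, and if it
    # happens to hold NaN or inf the running moments become NaN.
    fst_moment = torch.zeros(channels, dtype=torch.float64)
    snd_moment = torch.zeros(channels, dtype=torch.float64)

    for data in loader:
        # Accumulate in float64 so squaring integer rasters cannot overflow.
        data = data.to(torch.float64)
        b, c, h, w = data.shape
        nb_pixels = b * h * w
        sum_ = torch.sum(data, dim=[0, 2, 3])
        sum_of_square = torch.sum(data ** 2, dim=[0, 2, 3])
        fst_moment = (cnt * fst_moment + sum_) / (cnt + nb_pixels)
        snd_moment = (cnt * snd_moment + sum_of_square) / (cnt + nb_pixels)
        cnt += nb_pixels

    return fst_moment, torch.sqrt(snd_moment - (fst_moment ** 2))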

if __name__=="__main__":
    datasetType="training"
    input_dir = f"data/IMAGES/{datasetType}/"
    image_directory = f"{input_dir}input_multiband/"
    label_directory = f"{input_dir}label/"

    inputImages = sorted(glob(f"{image_directory}/*.tif"))
    labelImages = sorted(glob(f"{label_directory}/*.tif"))
    lenInputImages,lenLabelImages = len(inputImages),len(labelImages)
    assert lenInputImages == lenLabelImages

    input_dataset = batchDataset(inputs=inputImages)
    dataloader = DataLoader(input_dataset,batch_size=1, shuffle=False, num_workers=4)
    input_mean, input_std = meanAndStd(dataloader)

But this yields large errors in the standard deviation. I know that xarray and dask can be used for this kind of thing, so I tried the code below, but I think xarray is overcomplicating things by combining the TIFs by latitude and longitude, which I don't want because they overlap.

ds = xarray.open_mfdataset(inputImages,engine='rasterio', parallel=True,)

This gives a ValueError:

ValueError: cannot reindex or align along dimension 'x' because the (pandas) index has duplicate values

which persists even when I try the following (from https://github.com/pydata/xarray/discussions/6297):

def drop_duplicates_along_all_dims(obj, keep="first"):
    deduplicated = obj
    for dim in obj.dims:
        indexes = {dim: ~deduplicated.get_index(dim).duplicated(keep=keep)}
        deduplicated = deduplicated.isel(indexes)
    return deduplicated

ds = xarray.open_mfdataset(inputImages,engine='rasterio', parallel=True, preprocess=drop_duplicates_along_all_dims)

Is there a way to do this efficiently using xarray? Can I somehow ignore or reset the coords in the preprocessing step passed to the open_mfdataset function?

I would really appreciate any help! Thanks!


Solution

  • Simply make one pass through the data to obtain the per-channel mean, then make a second pass that accumulates the sum of squared deviations, (data - mean)**2. Dividing that sum by the total number of pixels and taking the square root gives the standard deviation. You would only run into trouble if you had to hold all the data in memory until the mean was known; two passes avoid that.
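
The two passes described above could be sketched roughly as follows. This is a minimal illustration using NumPy, not a tested pipeline for the GeoTIFF data: `two_pass_mean_std` and `load_batches` are hypothetical names, and `load_batches` is assumed to be a callable that returns a fresh iterator over arrays shaped (batch, channels, height, width) each time it is called, for example by re-iterating a DataLoader. Small random uint8 arrays stand in for the rasters.

```python
import numpy as np


def two_pass_mean_std(load_batches, channels=3):
    """Per-channel mean and population std from two passes over the data.

    load_batches: hypothetical callable returning a fresh iterator of arrays
    shaped (batch, channels, height, width). Only one batch is held in
    memory at a time.
    """
    # Pass 1: per-channel sum and pixel count give the mean.
    total = np.zeros(channels, dtype=np.float64)
    count = 0
    for batch in load_batches():
        batch = np.asarray(batch, dtype=np.float64)  # avoid integer overflow
        total += batch.sum(axis=(0, 2, 3))
        count += batch.shape[0] * batch.shape[2] * batch.shape[3]
    mean = total / count

    # Pass 2: accumulate squared deviations from the now-known mean.
    sq_dev = np.zeros(channels, dtype=np.float64)
    for batch in load_batches():
        batch = np.asarray(batch, dtype=np.float64)
        sq_dev += ((batch - mean[None, :, None, None]) ** 2).sum(axis=(0, 2, 3))

    std = np.sqrt(sq_dev / count)  # population std (divide by N, not N - 1)
    return mean, std


# Synthetic stand-in for the rasters: three batches of two 3-band 4x5 images.
rng = np.random.default_rng(0)
batches = [rng.integers(0, 256, size=(2, 3, 4, 5), dtype=np.uint8) for _ in range(3)]

mean, std = two_pass_mean_std(lambda: iter(batches))
print(mean, std)
```

Because the deviations are taken from the final mean rather than derived from E[X^2] - E[X]^2, the second pass avoids subtracting two large, nearly equal numbers, at the cost of reading every file twice.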