python, arrays, pytorch, zarr

Creating a generator over a zarr array with start and end for pytorch dataloader


I'm working on a pytorch project where my data is saved in zarr.

Random access into a zarr array is costly, but because zarr caches data blockwise, sequential iteration is really quick. To take advantage of this, I use an IterableDataset together with multiple workers:

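from itertools import islice

import zarr
from torch.utils.data import IterableDataset


class Data(IterableDataset):
    def __init__(self, path, start=None, end=None):
        super().__init__()
        store = zarr.DirectoryStore(path)
        self.array = zarr.open(store, mode='r')

        # Default to the full extent of the array along the first axis.
        if start is None:
            start = 0
        if end is None:
            end = self.array.shape[0]

        assert end > start

        self.start = start
        self.end = end

    def __iter__(self):
        # islice steps through every row before self.start before yielding.
        return islice(self.array, self.start, self.end)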
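How each worker gets its start and end is not shown above. As a minimal sketch, assuming the rows are split into contiguous, roughly equal ranges, a helper like the hypothetical `worker_range` below could compute them. In PyTorch, `worker_id` and `num_workers` would typically come from `torch.utils.data.get_worker_info()` inside `__iter__`; here they are passed in explicitly so the example runs without PyTorch.

```python
def worker_range(start, end, worker_id, num_workers):
    """Return the [lo, hi) part of [start, end) owned by one worker.

    Hypothetical helper, not part of PyTorch or zarr.
    """
    total = end - start
    per_worker = -(-total // num_workers)  # ceiling division
    lo = min(start + worker_id * per_worker, end)
    hi = min(lo + per_worker, end)
    return lo, hi


# Four hypothetical workers sharing rows 0..9.
ranges = [worker_range(0, 10, w, 4) for w in range(4)]
print(ranges)  # [(0, 3), (3, 6), (6, 9), (9, 10)]
```

Later workers get larger `lo` values, which is exactly where the startup cost described below comes from.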

The issue is that self.array has on the order of 10e9 rows. Each successive worker gets a larger self.start and self.end, so creating its generator with itertools.islice(array, start, end) takes a significant amount of time out of my training/validation process, because islice still has to iterate over all the unneeded elements before it reaches start. Once each worker has its generator, this works like a charm, but getting there takes too long.
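The cost of the skip can be seen with a small stand-alone experiment. `CountingIterable` below is a hypothetical stand-in for the zarr array that simply counts how many rows it has been asked to produce; it does not involve zarr at all.

```python
from itertools import islice


class CountingIterable:
    """Hypothetical stand-in for a large array; counts rows actually produced."""

    def __init__(self, n):
        self.n = n
        self.rows_read = 0

    def __iter__(self):
        for i in range(self.n):
            self.rows_read += 1
            yield i


data = CountingIterable(1_000_000)
first = next(islice(data, 900_000, 900_010))

# islice pulled every row up to and including index 900_000
# just to deliver the first requested one.
print(first, data.rows_read)  # 900000 900001
```

The work done before the first row is returned grows linearly with start, which matches the slow startup seen for the later workers.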

Is there a better way to create such a generator? Or maybe there's a smarter way to use zarr in pytorch?


Solution

  • Update: As of March 2021 this change has been merged into zarr.

    I took a small dive into zarr, and it looks like this is most easily enabled from inside zarr itself. I have opened an issue here; in the meantime, I made a fork of zarr that implements the method array.islice(start, end).

    The dataset __iter__ method then looks like this:

    def __iter__(self):
        return self.array.islice(self.start, self.end)
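For a zarr version without `islice`, a similar effect can be approximated from user code by slicing the array block by block. This is only a sketch and not how zarr implements `islice`: `iter_rows` and `block_size` are hypothetical names, and the only assumption is that `array[lo:hi]` returns the rows in that range, which holds for zarr arrays, NumPy arrays, and plain lists.

```python
def iter_rows(array, start, end, block_size=1024):
    """Yield rows array[start:end], fetching one contiguous block at a time."""
    for lo in range(start, end, block_size):
        hi = min(lo + block_size, end)
        block = array[lo:hi]  # one slice read per block; rows before start are never touched
        yield from block


rows = list(range(100))  # a plain list stands in for the zarr array here
print(list(iter_rows(rows, 90, 95, block_size=2)))  # [90, 91, 92, 93, 94]
```

With a real zarr array, choosing `block_size` as a multiple of the chunk length along the first axis (`array.chunks[0]`) should keep each chunk from being read more than once, though this has not been benchmarked here.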