Tags: python, keras, hdf5, tf.keras, h5py

Splitting up an `h5` file and combining the pieces back


I have an h5 file, which is basically model weights output by Keras. Because of some storage requirements, I'd like to split the large h5 file into smaller pieces and combine them back into a single file when needed. However, the way I do it seems to lose some "metadata" (I'm not sure; it may be missing a lot more, but judging by the sizes of the combined file and the original file, it seems I'm not missing much).

Here's my splitting script:

import os
import h5py

prefix = "model_weights"
fname_src = "DiffusiveSizeFactorAI/model_weights.h5"
size_max = 90 * 1024**2  # maximum size allowed in bytes
is_file_open = False
dest_fnames = []
idx = 0

with h5py.File(fname_src, "r") as src:
    for group in src:
        fname = f"{prefix}_{idx}.h5"
        if not is_file_open:
            dest = h5py.File(fname, "w")
            dest_fnames.append(fname)
            is_file_open = True
        group_id = dest.require_group(group)
        src.copy(f"/{group}", group_id)
        size = os.path.getsize(fname)
        # Once the current piece exceeds the size limit, close it and
        # start a new piece on the next iteration.
        if size > size_max:
            dest.close()
            idx += 1
            is_file_open = False
    # Close the final piece if it's still open (it may already have been
    # closed inside the loop).
    if is_file_open:
        dest.close()

and here's the script I use to combine the pieces back into a single file:

fname_combined = f"{prefix}_combined.h5"

with h5py.File(fname_combined, "w") as combined:
    for fname in dest_fnames:
        with h5py.File(fname, "r") as src:
            for group in src:
                group_id = combined.require_group(group)
                src.copy(f"/{group}", group_id)

To add a little context in case it helps with debugging: when I load the "combined" model weights, here's the error I get:

ValueError: Layer count mismatch when loading weights from file. Model expected 108 layers, found 0 saved layers.

Note: the sizes of the original file and the combined one are about the same (they differ by less than 0.5%), which is why I suspect I'm only missing some metadata.
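For reference, here's a quick way to compare what the two files expose at the root level (assuming `fname_src` and `fname_combined` as defined above). Keras's HDF5 weight format records the layer names as a `layer_names` attribute on the root group, which is what the loader counts, so a missing root attribute would explain "found 0 saved layers":

import h5py

for fname in [fname_src, fname_combined]:
    with h5py.File(fname, "r") as f:
        # Keras's HDF5 loader reads the "layer_names" attribute of the
        # root group; if it's absent, it reports 0 saved layers.
        print(fname, "groups:", list(f), "root attrs:", list(f.attrs))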


Solution

  • Based on an answer from the h5py developers, there are two issues:

    1. Every time a group is copied this way, an extra, duplicated group level is added to the destination file. Let me explain:

    Suppose in src.h5, I have the following structure: /A/B/C. In these two lines:

    group_id = dest.require_group(group)
    src.copy(f"/{group}", group_id)
    

    group is "A", and so, after copying, an extra /A level is added to dest.h5, which results in the erroneous structure /A/A/B/C. To fix that, copy the group directly into the destination file and explicitly pass name="A" as an argument to copy (see the sketch below).

    2. Metadata of the root level "/" is not copied in either the splitting or the combining script. To fix that, given that the h5 data structure is very similar to a Python dict, you just need to add:
    dest.attrs.update(src.attrs)
    

    For personal use, I've written two helper functions: one that splits up a large h5 file into smaller parts, each not exceeding a specified size (passed as an argument by the user), and another that combines them back into a single h5 file. In case you find them useful, they can be found on GitHub here.
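    Putting both fixes together, here is a minimal sketch of the corrected combining loop (filenames and `dest_fnames` as in the question; the splitting loop needs the same `name=` and `attrs.update` changes so that each piece actually carries the root attributes):

    import h5py

    with h5py.File(fname_combined, "w") as combined:
        for fname in dest_fnames:
            with h5py.File(fname, "r") as src:
                for group in src:
                    # Copy straight into the destination root under the
                    # group's own name, avoiding the duplicated /A/A level.
                    src.copy(f"/{group}", combined, name=group)
                # Carry over the root-level ("/") metadata as well.
                combined.attrs.update(src.attrs)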