python, parquet, pyarrow, apache-arrow

Selecting deep columns in pyarrow.dataset parquet


Let's say I have a deeply nested arrow table like:

pyarrow.Table
arr: struct<arr: struct<a: list<item: int64 not null> not null, b: list<item: int64 not null> not null> not null>
  child 0, arr: struct<a: list<item: int64 not null> not null, b: list<item: int64 not null> not null> not null
      child 0, a: list<item: int64 not null> not null
          child 0, item: int64 not null
      child 1, b: list<item: int64 not null> not null
          child 0, item: int64 not null
----
arr: [
  -- is_valid: all not null
  -- child 0 type: struct<a: list<item: int64 not null> not null, b: list<item: int64 not null> not null>
    -- is_valid: all not null
    -- child 0 type: list<item: int64 not null>
[[1,2,3],[1,2,3],[1,2,3],[1,2,3]]
    -- child 1 type: list<item: int64 not null>
[[3,4,5],[3,4,5],[3,4,5],[3,4,5]]]

I can write this to a parquet dataset with pyarrow.dataset.write_dataset.

With the legacy (now deprecated) ParquetDataset implementation in pyarrow.parquet, I could choose to read a selection of one or more of the leaf nodes like this:

import pyarrow.parquet as pq
pf = pq.ParquetDataset("temp.parq/")
pf.read(columns=["arr.arr.a.list.item"])

How do I achieve this with the pyarrow.dataset API? As far as I can tell, I can only select top-level fields, in this case ["arr"], which would get me both leaf nodes, not just one.

Whilst .to_table(columns=) is promising, it loses the original record structure of the data, so if I needed to pick more than one leaf out of many, they would all come back as independent top-level columns.

Compare:

> ds.to_table().to_pydict()
{'arr': [{'arr': {'a': [1, 2, 3], 'b': [3, 4, 5]}}, ...

> ds.to_table(columns={'leaf': pyarrow.dataset.field('arr', 'arr', 'a')}).to_pydict()
{'leaf': [[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]}

leaf: list<item: int64 not null> not null
    child 0, item: int64 not null

but it should be

{'arr': [{'arr': {'a': [1, 2, 3]}}, ..

arr: struct<arr: struct<a: list<item: int64 not null> not null> not null>
  child 0, arr: struct<a: list<item: int64 not null> not null> not null
    child 0, a: list<item: int64 not null> not null
      child 0, item: int64 not null

Solution

  • Your schema is the following:

    import pyarrow as pa
    
    schema = pa.schema(
        [
            pa.field(
                "arr",
                pa.struct(
                    [
                        pa.field(
                            "arr",
                            pa.struct(
                                [
                                    pa.field(
                                        "a",
                                        pa.list_(pa.int64())),
                                    pa.field(
                                        "b",
                                        pa.list_(pa.int64()))
                                ]
                            ))
                    ]
                ))
        ])
    
    

    and you want to drop b, so the sub-schema is:

    sub_schema = pa.schema(
        [
            pa.field(
                "arr",
                pa.struct(
                    [
                        pa.field("arr",
                                 pa.struct(
                                     [
                                         pa.field("a", pa.list_(pa.int64())),
                                     ]
                                 ))
                    ]
                ))
        ])
    
    

    I can't think of a way to achieve this with dataset. to_table uses from_dataset, which would flatten the extracted field.

    Also, changing the schema of the dataset doesn't seem to work: ds.replace_schema(sub_schema) throws an ArrowTypeError.

    However, you can load the dataset into a table and cast that table, which works:

    table = ds.to_table()
    # cast returns a new table; keep the result
    sub_table = table.cast(sub_schema)
    

    Another option is to provide the sub_schema when loading the dataset:

    pyarrow.dataset.dataset('./my_ds/', schema=sub_schema).to_table()