Tags: python, google-cloud-platform, pyarrow, apache-arrow

When constructing a pyarrow dataset with `FileSystemDataset.from_paths`, how do I specify partitions?


I have a list of CSV files in a Google Cloud Storage bucket organized like `gs://bucket/some_dir/{partition_value}/filename`. I want to create a pyarrow.Dataset from a list of URIs like this (a subset of the files in some_dir).

How do I do this and extract partition_value as a column?

So far I have:

import gcsfs
import pyarrow as pa
import pyarrow.csv
import pyarrow.dataset as ds
from pyarrow.fs import FSSpecHandler, PyFileSystem

fs = gcsfs.GCSFileSystem()
schema = pa.schema([("gene_id", pa.string()), ("raw_count", pa.float32()), ("scaled_estimate", pa.float32())])

# these data are publicly accessible, btw
uris = [
    "gs://gdc-tcga-phs000178-open/0b8b258e-1671-4f86-82e7-59b12ad40d9c/unc.edu.4c243ea9-dfe1-42f0-a887-3c901fb38542.2477720.rsem.genes.results",
    "gs://gdc-tcga-phs000178-open/c8ee8367-c529-4dd6-98b4-fde57991134b/unc.edu.a64ae1f5-a189-4173-be13-903bd7637869.2476757.rsem.genes.results",
    "gs://gdc-tcga-phs000178-open/78354f8d-5ce8-4617-bba4-79614f232e97/unc.edu.ac19f7cf-670b-4dcc-a26b-db0f56377231.2509607.rsem.genes.results",
]

dataset = ds.FileSystemDataset.from_paths(
    uris,
    schema,
    format=ds.CsvFileFormat(parse_options=pa.csv.ParseOptions(delimiter="\t")),
    filesystem=PyFileSystem(FSSpecHandler(fs)),
    # partitions=["bucket", "file_gcs_id"],
    # root_partition="gdc-tcga-phs000178-open",
)

dataset.to_table()

This gives me a nice table with the fields in my schema.

But I'd like partition_value to be another field in my dataset. I'm guessing I need:

  1. to add this as a field to my schema and
  2. to add something when calling FileSystemDataset.from_paths

I tried fiddling with root_partition, but got an error that the string I provided isn't a pyarrow.Expression (no idea what that is). I also tried specifying partitions, but that gives `ValueError: The number of files resulting from paths_or_selector must be equal to the number of partitions`.


Solution

  • During dataset discovery, filename information is used (along with a specified partitioning) to generate "guarantees" which are attached to fragments. For example, when we see the file foo/x=7/bar.parquet and we are using "hive partitioning", we can attach the guarantee x == 7. These guarantees are stored as "expressions" for various reasons we don't need to discuss at the moment.
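
    As a quick illustration, here is a minimal runnable sketch (using a local hive-partitioned directory named hive_example, an invented name for this example) that prints the guarantee expression attached to each discovered fragment:

    import shutil

    import pyarrow as pa
    import pyarrow.dataset as ds

    shutil.rmtree('hive_example', ignore_errors=True)

    table = pa.Table.from_pydict({'y': [1, 2], 'x': [7, 8]})
    ds.write_dataset(
        table, 'hive_example', format='parquet',
        partitioning=ds.partitioning(pa.schema([pa.field('x', pa.int64())]), flavor='hive'))

    # Each fragment carries its partition guarantee as an expression,
    # e.g. hive_example/x=7/part-0.parquet -> (x == 7)
    dataset = ds.dataset('hive_example', partitioning='hive')
    for frag in dataset.get_fragments():
        print(frag.path, frag.partition_expression)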

    Two solutions jump to mind. First, you could create the guarantees yourself and attach them to your paths (this is what the partitions argument represents in the from_paths method). The expression should be ds.field("column_name") == value.

    Second, you could allow the dataset discovery process to run as normal. This will generate all the fragments you need (and some you don't), with the guarantees already attached. Then you could trim down the list of fragments to your desired list of fragments and create a dataset from that.

    (I'm guessing I need) to add this as a field to my schema

    Yes. In both of the above approaches you will want to make sure your partitioning column(s) are added to your schema.
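
    For the question's schema, that could be as simple as the following sketch (the column name file_gcs_id is borrowed from the commented-out attempt in the question):

    # append the partition column to the schema defined in the question
    schema = schema.append(pa.field("file_gcs_id", pa.string()))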

    Here is a code example showing both approaches:

    import shutil
    
    import pyarrow as pa
    import pyarrow.dataset as ds
    import pyarrow.fs as fs
    
    shutil.rmtree('my_dataset', ignore_errors=True)
    
    table = pa.Table.from_pydict({
        'x': [1, 2, 3, 4, 5, 6],
        'part': ['a', 'a', 'a', 'b', 'b', 'b']
    })
    
    ds.write_dataset(table, 'my_dataset', partitioning=['part'], format='parquet')
    
    print('# Created by dataset factory')
    partitioning = ds.partitioning(schema=pa.schema([pa.field('part', pa.string())]))
    dataset = ds.dataset('my_dataset', partitioning=partitioning)
    print(dataset.to_table())
    print()
    
    desired_paths = [
        'my_dataset/a/part-0.parquet'
    ]
    
    # Note that table.schema used below includes the partitioning
    # column so we've added that to the schema.
    print('# Created from paths')
    filesystem = fs.LocalFileSystem()
    dataset_from_paths = ds.FileSystemDataset.from_paths(
        desired_paths,
        table.schema,
        format=ds.ParquetFileFormat(),
        filesystem=filesystem)
    print(dataset_from_paths.to_table())
    print()
    
    print('# Created from paths with explicit partition information')
    dataset_from_paths = ds.FileSystemDataset.from_paths(
        desired_paths,
        table.schema,
        partitions=[
            ds.field('part') == "a"
        ],
        format=ds.ParquetFileFormat(),
        filesystem=filesystem)
    print(dataset_from_paths.to_table())
    print()
    
    print('# Created from discovery then trimmed')
    trimmed_fragments = [frag for frag in dataset.get_fragments() if frag.path in desired_paths]
    trimmed_dataset = ds.FileSystemDataset(trimmed_fragments, dataset.schema, dataset.format, filesystem=dataset.filesystem)
    print(trimmed_dataset.to_table())
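
    Applied back to the question's setup, the first approach would look roughly like the sketch below. This is untested against GCS: it reuses the question's URIs and filesystem wiring, borrows the column name file_gcs_id from the commented-out attempt, and assumes the partition value is the first path segment after the bucket.

    import gcsfs
    import pyarrow as pa
    import pyarrow.csv
    import pyarrow.dataset as ds
    from pyarrow.fs import FSSpecHandler, PyFileSystem

    fs = gcsfs.GCSFileSystem()

    # schema from the question, plus the partition column
    schema = pa.schema([
        ("gene_id", pa.string()),
        ("raw_count", pa.float32()),
        ("scaled_estimate", pa.float32()),
        ("file_gcs_id", pa.string()),
    ])

    uris = [
        "gs://gdc-tcga-phs000178-open/0b8b258e-1671-4f86-82e7-59b12ad40d9c/unc.edu.4c243ea9-dfe1-42f0-a887-3c901fb38542.2477720.rsem.genes.results",
        "gs://gdc-tcga-phs000178-open/c8ee8367-c529-4dd6-98b4-fde57991134b/unc.edu.a64ae1f5-a189-4173-be13-903bd7637869.2476757.rsem.genes.results",
        "gs://gdc-tcga-phs000178-open/78354f8d-5ce8-4617-bba4-79614f232e97/unc.edu.ac19f7cf-670b-4dcc-a26b-db0f56377231.2509607.rsem.genes.results",
    ]

    # gs://{bucket}/{partition_value}/{filename} -> take the directory segment
    partition_values = [uri.split("/")[3] for uri in uris]

    dataset = ds.FileSystemDataset.from_paths(
        uris,
        schema,
        format=ds.CsvFileFormat(parse_options=pa.csv.ParseOptions(delimiter="\t")),
        filesystem=PyFileSystem(FSSpecHandler(fs)),
        # one guarantee expression per file, in the same order as uris
        partitions=[ds.field("file_gcs_id") == v for v in partition_values],
    )

    dataset.to_table()  # now includes file_gcs_id as a column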