python, apache-spark, pyspark, parquet

Reading a single Parquet partition with a single file results in a DataFrame with more partitions


Context

I have a Parquet table stored in HDFS with two partitions, where each partition contains only one file.

parquet_table \
    | year=2020 \
        file_1.snappy.parquet
    | year=2021 \
        file_2.snappy.parquet

My plan is to grab only the latest partition and work on that.

df = spark.read.parquet("hdfs_path_to_table/parquet_table/year=2021/")

This works, and I retrieve only the required data. While I wrote this for PySpark, I assume the behavior is analogous in plain (Scala) Spark.

Problem

Although I retrieve the correct data, the DataFrame df is still backed by two partitions:

df.rdd.getNumPartitions()
# -> 2

When I count the rows inside each partition, I see that only one of them contains data:

df.rdd.mapPartitions(lambda partition: [len([row for row in partition])]).collect()
# -> [1450220, 0]

Of course I could simply call df.coalesce(1) and end up with the desired result. Still, I wonder why this happens, and I would rather load only that partition directly than have to coalesce afterwards.

Question

Is there a way for my DataFrame df to report the correct .getNumPartitions()? In other words, is there a way to load a single Parquet file so that it ends up in a single partition?


Solution

  • One of the issues is that partition is an overloaded term in the Spark world, and you're looking at two different kinds of partitions:

    • your dataset is organized as a Hive-partitioned table, where each partition is a separate directory named <partition_attribute>=<partition_value> that may contain many data files. This layout is only useful for dynamically pruning the set of input files to read and has no effect on the actual RDD processing (see the sketch after this list)

    • when Spark loads your data and creates a DataFrame/RDD, that RDD is organized in splits that can be processed in parallel; these splits are also called partitions.
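
    A quick way to see what that pruning looks like, reusing the base path from the question (a minimal sketch; the exact plan output may vary by Spark version):

    # Reading the base path and filtering on the partition column lets Spark
    # skip the year=2020 file entirely; the pruning shows up in the query plan,
    # not as a change in the number of RDD partitions.
    df = spark.read.parquet("hdfs_path_to_table/parquet_table/").filter("year = 2021")
    df.explain()  # the FileScan node should show a PartitionFilters entry for year = 2021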

    df.rdd.getNumPartitions() returns the number of splits in your data, and that is completely unrelated to your input table partitioning. It's determined by a number of config options but is mostly driven by 3 factors (see the sketch after the list below):

    • computing parallelism: spark.default.parallelism in particular is the reason why you have 2 partitions in your RDD even though you don't have enough data to fill more than one
    • input size: Spark will try not to create partitions bigger than spark.sql.files.maxPartitionBytes and thus may split a single multi-gigabyte Parquet file into many partitions
    • shuffling: any operation that needs to reorganize data for correct behavior (for example join or groupBy) will repartition your RDD with a new strategy, and you will end up with many more partitions (governed by spark.sql.shuffle.partitions and AQE settings)
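
    A minimal sketch of how to inspect and tweak those knobs (paths are the ones from the question; the 64 MB value is only illustrative):

    # Inspect the settings that drive the number of RDD splits
    print(spark.sparkContext.defaultParallelism)                # computing parallelism
    print(spark.conf.get("spark.sql.files.maxPartitionBytes"))  # max bytes per input split
    print(spark.conf.get("spark.sql.shuffle.partitions"))       # partitions produced by a shuffle

    # Lowering maxPartitionBytes splits large files into more partitions;
    # raising it packs more data into fewer partitions (illustrative value)
    spark.conf.set("spark.sql.files.maxPartitionBytes", str(64 * 1024 * 1024))

    df = spark.read.parquet("hdfs_path_to_table/parquet_table/year=2021/")
    print(df.rdd.getNumPartitions())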

    On the whole, you want to preserve this behavior, since it's what lets Spark process your data in parallel and achieve good performance. If you use df.coalesce(1), you will coalesce your data into a single RDD partition, but then all your processing runs on a single core, in which case simply doing your work in Pandas and/or PyArrow would be much faster.
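
    That said, if the data is small and a single in-memory partition is really what you want, coalesce gives exactly that (a short sketch; no shuffle is involved, but everything downstream then runs on one core):

    # Merge the existing splits into a single partition without a shuffle
    df_single = df.coalesce(1)
    print(df_single.rdd.getNumPartitions())  # -> 1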

    If what you want is to preserve the property that your output has a single Parquet file per Hive-partition attribute, you can use the following construct:

    # Read your partitioned dataset and filter on your preferred partition(s)
    df = spark.read.parquet("hdfs_path_to_table/parquet_table/").filter("year = 2021")

    # do your work
    df_output = df.<do_something>

    # repartition impacts how Spark organizes the data in RDD splits
    df_repartitioned = df_output.repartition("<partition_attribute>")

    # partitionBy impacts how Spark organizes data on disk in separate folders
    df_repartitioned.write.mode("overwrite").partitionBy("<partition_attribute>").parquet("hdfs_output")

    If you only process some of your partitions and don't want to overwrite the complete output every time, be sure to set spark.sql.sources.partitionOverwriteMode=dynamic so that only the affected Hive partitions are overwritten.
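
    A minimal sketch of that, assuming year is the partition attribute and reusing the hdfs_output path from above:

    # Only the Hive partitions present in df_repartitioned (here: year=2021)
    # are overwritten; year=2020 stays untouched on disk
    spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
    df_repartitioned.write.mode("overwrite").partitionBy("year").parquet("hdfs_output")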