
How to efficiently create a PySpark dataframe with only data from some hive-style partitions?


My data is stored in parquet files using hive-style partitioning on two partition keys, source_id and year_month, similar to the following structure:

/data/plandata/
├── source_id=1/
│   ├── year_month=2023-08-01/
│   │   ├── file1.parquet
│   │   └── ...
│   ├── year_month=2023-09-01/
│   │   ├── file2.parquet
│   │   └── ...
├── source_id=2/
│   ├── year_month=2023-09-01/
│   │   ├── file3.parquet
│   │   └── ...
│   ├── year_month=2023-10-01/
│   │   ├── file4.parquet
│   │   └── ...
└── ...

I want to create a PySpark dataframe containing only the most recent year_month data for each source_id. The most recent year_month may be different for different source_id values (and will change over multiple runs of my code) so this must be dynamic.

An approach that works is to:

  1. first create a dataframe with just the source_id and year_month and group by source_id taking max year_month
  2. then use this in a join on all of the data
from pyspark.sql import functions as F
basepath = '.../test_data'

# prepare test data - just for example, in practice my data already exists
data = [
    {'source_id': 1, 'year_month': '2023-08-01', 'dataval': 'a-01-08'},
    {'source_id': 1, 'year_month': '2023-08-01', 'dataval': 'b-01-08'},
    {'source_id': 1, 'year_month': '2023-08-01', 'dataval': 'c-01-08'},
    {'source_id': 1, 'year_month': '2023-09-01', 'dataval': 'a-01-09'},
    {'source_id': 1, 'year_month': '2023-09-01', 'dataval': 'b-01-09'},
    {'source_id': 1, 'year_month': '2023-09-01', 'dataval': 'c-01-09'},
    {'source_id': 2, 'year_month': '2023-09-01', 'dataval': 'a-02-09'},
    {'source_id': 2, 'year_month': '2023-09-01', 'dataval': 'b-02-09'},
    {'source_id': 2, 'year_month': '2023-09-01', 'dataval': 'c-02-09'},
    {'source_id': 2, 'year_month': '2023-10-01', 'dataval': 'a-02-10'},
    {'source_id': 2, 'year_month': '2023-10-01', 'dataval': 'b-02-10'},
    {'source_id': 2, 'year_month': '2023-10-01', 'dataval': 'c-02-10'},
    {'source_id': 3, 'year_month': '2023-08-01', 'dataval': 'a-03-08'},
    {'source_id': 3, 'year_month': '2023-08-01', 'dataval': 'b-03-08'},
    {'source_id': 3, 'year_month': '2023-08-01', 'dataval': 'c-03-08'},
    {'source_id': 3, 'year_month': '2023-10-01', 'dataval': 'a-03-10'},
    {'source_id': 3, 'year_month': '2023-10-01', 'dataval': 'b-03-10'},
    {'source_id': 3, 'year_month': '2023-10-01', 'dataval': 'c-03-10'},
]
df = spark.createDataFrame(data)
df.write.mode('overwrite').partitionBy('source_id', 'year_month').parquet(basepath)


# reading in the data - this is the bit I need to figure out
df_recent_partitions = spark.read.parquet(basepath).select('source_id', 'year_month')
df_recent_partitions = df_recent_partitions.groupBy('source_id').agg(
    F.max('year_month').alias('max_year_month')
)
df_latest_data = spark.read.parquet(basepath)
df_latest_data = df_latest_data.join(
    df_recent_partitions,
    (df_latest_data.source_id == df_recent_partitions.source_id)
    & (df_latest_data.year_month == df_recent_partitions.max_year_month),
).select(df_latest_data['*'])

df_latest_data.show()

+-------+---------+----------+
|dataval|source_id|year_month|
+-------+---------+----------+
|a-02-10|        2|2023-10-01|
|b-02-10|        2|2023-10-01|
|c-02-10|        2|2023-10-01|
|a-03-10|        3|2023-10-01|
|b-03-10|        3|2023-10-01|
|c-03-10|        3|2023-10-01|
|a-01-09|        1|2023-09-01|
|b-01-09|        1|2023-09-01|
|c-01-09|        1|2023-09-01|
+-------+---------+----------+

So this works as expected: for each source_id I have only the most recent year_month data.

My questions:

  1. Is this an efficient way to achieve my desired result?
    I'm assuming that a) in creating df_recent_partitions, Spark is not reading data from all the files (since I am only selecting source_id and year_month, which can be determined from the folder names alone), and b) in creating df_latest_data, it is smart enough to read data only from the files in the folders that are relevant to the join, and to skip the others. But I don't know if this is what's actually happening.
  2. Is there a way to validate my assumptions? I generated the explain plans (see below) for df_recent_partitions and df_latest_data but don't know how to interpret them.
df_recent_partitions.explain()

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- HashAggregate(keys=[source_id#39], functions=[max(year_month#40)])
   +- Exchange hashpartitioning(source_id#39, 200), ENSURE_REQUIREMENTS, [plan_id=449]
      +- HashAggregate(keys=[source_id#39], functions=[partial_max(year_month#40)])
         +- FileScan parquet [source_id#39,year_month#40] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/content/drive/MyDrive/PySpark exploration/test_data], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<>

#----------------------------------

df_latest_data.explain()

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [dataval#127, source_id#128, year_month#129]
   +- BroadcastHashJoin [source_id#128, year_month#129], [source_id#39, max_year_month#52], Inner, BuildRight, false
      :- FileScan parquet [dataval#127,source_id#128,year_month#129] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/content/drive/MyDrive/PySpark exploration/test_data], PartitionFilters: [isnotnull(source_id#128), isnotnull(year_month#129)], PushedFilters: [], ReadSchema: struct<dataval:string>
      +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, int, true], input[1, date, false]),false), [plan_id=435]
         +- Filter isnotnull(max_year_month#52)
            +- HashAggregate(keys=[source_id#39], functions=[max(year_month#40)])
               +- Exchange hashpartitioning(source_id#39, 200), ENSURE_REQUIREMENTS, [plan_id=431]
                  +- HashAggregate(keys=[source_id#39], functions=[partial_max(year_month#40)])
                     +- FileScan parquet [source_id#39,year_month#40] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/content/drive/MyDrive/PySpark exploration/test_data], PartitionFilters: [isnotnull(source_id#39)], PushedFilters: [], ReadSchema: struct<>

Solution

  • As I don't have your dataset, I use my own dataset to mimic yours:

    root
     |-- id: string (nullable = true)
     |-- ts: long (nullable = true)
     |-- year: integer (nullable = true)
     |-- month: integer (nullable = true)
     |-- day: integer (nullable = true)
    

    This is my dataset; ts is a Unix timestamp, and year, month and day hold the corresponding date information. The dataset is partitioned by the year, month and day columns. If I try to apply logic similar to yours:

    df = spark.read.parquet(path)
    df_agg = df.groupBy('year').agg(func.max('month').alias('max_month'))
    df_agg.explain()
    
    == Physical Plan ==
    AdaptiveSparkPlan isFinalPlan=false
    +- HashAggregate(keys=[year#2], functions=[max(month#3)])
       +- Exchange hashpartitioning(year#2, 200), ENSURE_REQUIREMENTS, [plan_id=13]
          +- HashAggregate(keys=[year#2], functions=[partial_max(month#3)])
             +- Project [year#2, month#3]
                +- FileScan parquet [year#2,month#3,day#4] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[mypath], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<>
    

    To answer your questions:

    1. Even though you indicate a column selection, that only triggers projection pushdown, which minimizes the columns you have to read. It's inevitable that the underlying parquet files are scanned to collect column information, even when you're only reading the partitioning columns. You can validate this by checking your Spark history server to see the query DAG: you should see the columns and the size of the files that your Spark application scans, as you are using parquet as storage. You should also see a projection node in your query plan, Project [year#2, month#3] (I'm not sure why you don't have it in your query plan; it depends on the Spark version and platform you use).
    2. I don't think this is an efficient way to achieve your goal. As you can see in the df_latest_data plan, your application involves 2 file scans and 1 join. Even though the dataframe is automatically broadcast, it's still a heavy workload. Since you just want to do some searching based on the path strings, why not just use the glob module on the file system, or use the Spark context to get the file status in HDFS? That should be much, much faster, as it doesn't require any file scanning or file transfer (see the sketch below).
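
    For example, here is a minimal sketch of that directory-listing idea for a local filesystem, assuming the hive-style layout from the question (spark and basepath as in the question; the latest and paths variables are mine; on HDFS or cloud storage you would use the corresponding FileSystem listing API instead of glob):

    import glob
    import os
    import re

    # Derive the newest year_month per source_id purely from the folder names,
    # without opening any parquet file. Lexicographic comparison works because
    # year_month is formatted as YYYY-MM-DD.
    latest = {}
    for p in glob.glob(os.path.join(basepath, 'source_id=*', 'year_month=*')):
        m = re.search(r'source_id=([^/\\]+)[/\\]year_month=([^/\\]+)', p)
        if m:
            sid, ym = m.group(1), m.group(2)
            if sid not in latest or ym > latest[sid]:
                latest[sid] = ym

    # Read only the latest partition per source_id. The basePath option keeps
    # the partition columns source_id and year_month in the resulting dataframe.
    paths = [f'{basepath}/source_id={sid}/year_month={ym}' for sid, ym in latest.items()]
    df_latest_data = spark.read.option('basePath', basepath).parquet(*paths)

    The max-per-source_id computation happens on the driver as plain string handling over the directory names, which is what makes this cheap compared with scanning the parquet files.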

    -----

    Edit 1 on 2023-11-14

    A where clause or filter will trigger predicate pushdown, so that only the matching rows are read, while select(cols) will trigger projection pushdown, so that only the selected columns are returned.
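
    As a rough illustration of the difference (reusing basepath from the question; the exact plan text varies by Spark version), compare the two explain() outputs:

    from pyspark.sql import functions as func

    # projection pushdown only: the scan still covers every partition,
    # but only the requested columns are materialized
    spark.read.parquet(basepath).select('source_id', 'year_month').explain()

    # a filter on partition columns should additionally show up under
    # PartitionFilters in the FileScan node, i.e. partition pruning
    spark.read.parquet(basepath).filter(func.col('source_id') == 2).explain()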

    Based on your approach and logical plan, your procedure is:

    1. Read all the partitions and select only the source_id and year_month columns (triggering projection pushdown).
    2. Calculate the max value of year_month for each source_id.
    3. Read all the partitions again.
    4. Join the two dataframes to return the desired output.

    In your step 1, you use df_recent_partitions = spark.read.parquet(basepath).select('source_id', 'year_month') to read every source_id + year_month combination, which means you don't have any predicate pushdown and you scan all the underlying parquet files.

    Only when you call something like df_recent_partitions = spark.read.option('basePath', basepath).parquet(basepath+"/source_id=2").select('source_id', 'year_month') does it trigger predicate pushdown (partition pruning), so that only the rows you need are filtered and selected.
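
    For completeness, a hedged sketch of how the directory-listing idea and partition pruning can be combined without hard-coding a single source_id: once the newest year_month per source_id is known (for example the latest dict from the earlier sketch; that dict is assumed, not part of the original answer), a filter that references only the partition columns can be built dynamically, and it should then appear under PartitionFilters in the FileScan node of explain():

    from functools import reduce
    from pyspark.sql import functions as func

    # latest maps each source_id folder value to its newest year_month string,
    # e.g. {'1': '2023-09-01', '2': '2023-10-01', '3': '2023-10-01'} for the
    # question's test data; Spark casts these strings when comparing them with
    # the typed partition columns.
    cond = reduce(
        lambda acc, c: acc | c,
        [
            (func.col('source_id') == sid) & (func.col('year_month') == ym)
            for sid, ym in latest.items()
        ],
    )
    df_latest_data = spark.read.parquet(basepath).filter(cond)
    df_latest_data.explain()  # the condition should show up as PartitionFilters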