Tags: python, pyspark

How to read / restore a checkpointed DataFrame across batches


I need to "checkpoint" certain information during my batch processing with PySpark that is needed in the next batches.

For this use case, DataFrame.checkpoint seems to fit. While I found many places that explain how to create a checkpoint, I did not find any that explain how to restore or read one.

To test this, I created a simple test class with two tests. The first reads a CSV and computes a sum. The second should pick up that sum and continue summing:

import pytest
from pyspark.sql import functions as f

class TestCheckpoint:

    @pytest.fixture(autouse=True)
    def init_test(self, spark_unit_test_fixture, data_dir, tmp_path):
        self.spark = spark_unit_test_fixture
        self.dir = data_dir("")
        self.checkpoint_dir = tmp_path

    def test_first(self):
        df = (self.spark.read.format("csv")
              .option("pathGlobFilter", "numbers.csv")
              .load(self.dir))

        sum = df.agg(f.sum("_c1").alias("sum"))
        sum.checkpoint()
        assert 1 == 1

    def test_second(self):
        df = (self.spark.read.format("csv")
              .option("pathGlobFilter", "numbers2.csv")
              .load(self.dir))

        sum = # how to get back the sum?

Creating the checkpoint in the first test works fine (I set tmp_path as the checkpoint directory), and I see a folder created with a file.
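
For reference, the checkpoint directory is set inside the fixture roughly like this (a simplified sketch of my setup):

# inside init_test, after assigning self.spark
self.spark.sparkContext.setCheckpointDir(str(tmp_path))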

But how do I read it?

And how do you handle multiple checkpoints? For example, one checkpoint on the sum and another for the average?

Are there better approaches to storing state across batches?

For the sake of completeness, the CSV looks like this:

1719228973,1
1719228974,2

And this is only a minimal example to get it running - my real scenario is more complex.


Solution

  • While in theory checkpoints are retained across Spark jobs and could in principle be read by other Spark jobs directly from the stored files, without recomputing the entire lineage, Spark does not make it easy to read a checkpoint back from those files in another job. If you are interested, here is an answer that shows what the file names look like when checkpointing happens.

    So in your case, I would advise writing the data to disk yourself and reading the file(s) back when you need them in another job. You can use a storage format like Parquet, which is efficient depending on your data and the nature of your processing. Something like this:

    import pytest
    from pyspark.sql import functions as f
    
    class TestCheckpoint:
    
        @pytest.fixture(autouse=True)
        def init_test(self, spark_unit_test_fixture, data_dir, tmp_path):
            self.spark = spark_unit_test_fixture
            self.dir = data_dir("")
            self.checkpoint_dir = tmp_path
    
        def test_first(self):
            df = (self.spark.read.format("csv")
                  .option("pathGlobFilter", "numbers.csv")
                  .load(self.dir))
    
            sum_df = df.agg(f.sum("_c1").alias("sum"))
            # persist the aggregated sum so the next batch/test can read it back
            sum_df.write.mode("overwrite").parquet(str(self.checkpoint_dir / "sum"))
            assert 1 == 1
    
        def test_second(self):
            # read back the sum written by the previous batch/test
            previous_sum = self.spark.read.parquet(str(self.checkpoint_dir / "sum"))
            previous_sum_value = previous_sum.collect()[0]["sum"]
    
            df = (self.spark.read.format("csv")
                  .option("pathGlobFilter", "numbers2.csv")
                  .load(self.dir))
    
            new_sum = df.agg(f.sum("_c1").alias("sum"))
            # combine the restored sum with this batch's sum to get the running total
            total_sum = previous_sum_value + new_sum.collect()[0]["sum"]
    
            assert 1 == 1
    
    

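    To handle multiple pieces of state, such as the sum and the average you mention, you can simply write each aggregate to its own subdirectory and read them back independently. A minimal sketch along the same lines (the subdirectory names are just an example):

    # in test_first: write each aggregate to its own location
    sum_df = df.agg(f.sum("_c1").alias("sum"))
    avg_df = df.agg(f.avg("_c1").alias("avg"))
    sum_df.write.mode("overwrite").parquet(str(self.checkpoint_dir / "sum"))
    avg_df.write.mode("overwrite").parquet(str(self.checkpoint_dir / "avg"))

    # in test_second: restore each one independently
    previous_sum = self.spark.read.parquet(str(self.checkpoint_dir / "sum"))
    previous_avg = self.spark.read.parquet(str(self.checkpoint_dir / "avg"))
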
    That said, if you need to access the checkpointed data within the same Spark job, you can simply keep a reference to the DataFrame, like so:

    # requires the checkpoint directory to have been set via spark.sparkContext.setCheckpointDir(...)
    sum = df.agg(f.sum("_c1").alias("sum"))
    sum = sum.checkpoint()  # checkpoint() returns a new DataFrame; hold on to this reference
    

    Alternatively, you can use df.persist(StorageLevel.DISK_ONLY), which also stores the data on disk while preserving the lineage. However, once the job ends, the persisted data is purged.
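
    A minimal sketch of that approach, assuming a SparkSession named spark and the same numbers.csv as above:

    from pyspark import StorageLevel
    from pyspark.sql import functions as f

    df = spark.read.format("csv").load("numbers.csv")
    sum_df = df.agg(f.sum("_c1").alias("sum"))
    # keep the result on disk for reuse within this job; it is purged when the job ends
    sum_df.persist(StorageLevel.DISK_ONLY)
    sum_df.count()  # trigger materialization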