Tags: apache-spark, palantir-foundry, foundry-code-repositories

How do I compute my Foundry 'latest version' dataset faster?


I have a dataset that ingests the latest edits to rows of my data, but each ingest contains only the recently edited row versions (i.e. it's incremental on an update_ts timestamp column).

Original table:

| primary_key | update_ts |
|-------------|-----------|
| key_1       | 0         |
| key_2       | 0         |
| key_3       | 0         |

Table as it gets updated:

| primary_key | update_ts |
|-------------|-----------|
| key_1       | 0         |
| key_2       | 0         |
| key_3       | 0         |
| key_1       | 1         |
| key_2       | 1         |
| key_1       | 2         |

After ingestion, I need to compute the 'latest version' for all prior updates while also taking into account any new edits.

This means I am taking the incremental ingest and running a SNAPSHOT output each time. This is very slow for my build because I have to scan all of my output rows every time I want to compute the latest version of my data.

Transaction n=1 (SNAPSHOT):

| primary_key | update_ts |
|-------------|-----------|
| key_1       | 0         |
| key_2       | 0         |
| key_3       | 0         |

Transaction n=2 (APPEND):

| primary_key | update_ts |
|-------------|-----------|
| key_1       | 1         |
| key_2       | 1         |
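
Roughly, my current snapshot transform looks like the following simplified sketch (a window-based dedup over the full input; the dataset paths and column names are just the ones used elsewhere in this post, and the real logic may differ):

    from transforms.api import transform, Input, Output
    import pyspark.sql.functions as F
    from pyspark.sql.window import Window


    @transform(
        my_output=Output("/datasets/clean_dataset"),
        my_input=Input("/datasets/raw_dataset")
    )
    def compute_latest(my_input, my_output):
        # Read every row ever ingested, then keep only the newest row per key.
        df = my_input.dataframe()

        w = Window.partitionBy("primary_key").orderBy(F.col("update_ts").desc())
        latest_df = (
            df.withColumn("__rank", F.row_number().over(w))
              .filter(F.col("__rank") == 1)
              .drop("__rank")
        )

        # Full SNAPSHOT on every build: the entire output is shuffled and rewritten.
        my_output.write_dataframe(latest_df)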

How can I make this 'latest version' computation faster?


Solution

  • This is a common pattern that will benefit from bucketing.

    The gist of this is: write your output SNAPSHOT into buckets based on your primary_key column, so that the expensive step of shuffling your much larger output is skipped entirely.

    This means you only have to exchange your new data into the buckets that already contain your prior history.

    Let's start from the initial state, where the prior-computed 'latest' version was built as a slow full SNAPSHOT:

    - output: raw_dataset
      input: external_jdbc_system
      hive_partitioning: none
      bucketing: none
      transactions:
        - SNAPSHOT
        - APPEND
        - APPEND
    - output: clean_dataset
      input: raw_dataset
      hive_partitioning: none
      bucketing: none
      transactions:
        - SNAPSHOT
        - SNAPSHOT
        - SNAPSHOT
    

    If we write out clean_dataset bucketed over the primary_key column, with a bucket count chosen up front to fit the data scale we anticipate, we would need the following code:

    from transforms.api import transform, incremental, Input, Output
    import pyspark.sql.functions as F


    @incremental()  # needed for the "added"/"current" read modes and set_mode below
    @transform(
        my_output=Output("/datasets/clean_dataset"),
        my_input=Input("/datasets/raw_dataset")
    )
    def my_compute_function(my_input, my_output):

        BUCKET_COUNT = 600
        PRIMARY_KEY = "primary_key"
        ORDER_COL = "update_ts"

        # Only the rows appended since the last build, plus the output we last wrote.
        updated_keys = my_input.dataframe("added")
        last_written = my_output.dataframe("current")

        # Shuffle just the (small) increment into the same layout as the output buckets.
        updated_keys = updated_keys.repartition(BUCKET_COUNT, PRIMARY_KEY)

        value_cols = [x for x in last_written.columns if x != PRIMARY_KEY]

        # Prefix the non-key columns so the two sides can be told apart after the join.
        updated_keys = updated_keys.select(
            PRIMARY_KEY,
            *[F.col(x).alias("updated_keys_" + x) for x in value_cols]
        )

        last_written = last_written.select(
            PRIMARY_KEY,
            *[F.col(x).alias("last_written_" + x) for x in value_cols]
        )

        all_rows = updated_keys.join(last_written, PRIMARY_KEY, "fullouter")

        # Take the freshly updated value for a key where it exists, otherwise keep history.
        latest_df = all_rows.select(
            PRIMARY_KEY,
            *[F.coalesce(
                F.col("updated_keys_" + x),
                F.col("last_written_" + x)
              ).alias(x) for x in value_cols]
        )

        # Rewrite the full output as a SNAPSHOT, bucketed exactly as before.
        my_output.set_mode("replace")

        my_output.write_dataframe(
            latest_df,
            bucket_cols=PRIMARY_KEY,
            bucket_count=BUCKET_COUNT,
            sort_by=ORDER_COL
        )
    

    When this runs, you'll notice in your query plan that the project step over the output no longer includes an exchange, which means that data is not being shuffled. The only exchange you'll now see is on the input, where the new rows have to be distributed into exactly the same layout as the output buckets (a very fast operation on a small increment).

    This partitioning is then carried into the fullouter join step, which exploits it and runs its 600 tasks very quickly. Finally, we preserve the layout on the output by explicitly bucketing into the same number of buckets over the same column as before.
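
    If you want to see the underlying Spark mechanism outside Foundry, here is a minimal, self-contained PySpark sketch; the table name, row counts, and the 600-bucket figure are illustrative only and not taken from the original build:

    from pyspark.sql import SparkSession
    import pyspark.sql.functions as F

    spark = SparkSession.builder.appName("bucketed_join_demo").getOrCreate()

    # Stand-in for the previously written 'latest' output, stored bucketed by the key.
    history = spark.range(1000000).select(
        F.concat(F.lit("key_"), F.col("id").cast("string")).alias("primary_key"),
        F.lit(0).alias("update_ts"),
    )
    (
        history.write.mode("overwrite")
        .bucketBy(600, "primary_key")
        .sortBy("update_ts")
        .saveAsTable("clean_dataset_demo")
    )

    # Stand-in for the new increment, shuffled once into the same 600-way layout.
    updates = spark.range(1000).select(
        F.concat(F.lit("key_"), F.col("id").cast("string")).alias("primary_key"),
        F.lit(1).alias("update_ts"),
    ).repartition(600, "primary_key")

    joined = updates.join(spark.table("clean_dataset_demo"), "primary_key", "fullouter")

    # With bucketed reads enabled (the Spark default), the physical plan should show
    # an Exchange only above the small 'updates' side; the bucketed table feeds the
    # sort-merge join without being reshuffled.
    joined.explain()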

    NOTE: with this approach, the files in each bucket will grow over time, and nothing automatically raises the bucket count to keep them nicely sized. You will eventually hit a threshold where file sizes climb above 128MB and the build is no longer executing efficiently (the fix is to bump the BUCKET_COUNT value). A rough way to size BUCKET_COUNT up front is sketched below.
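
    As a back-of-the-envelope sizing rule (an assumption of this write-up, not part of the original build), divide the compressed size you expect the output to reach by a target of roughly 128MB per bucket:

    TARGET_BUCKET_MB = 128       # aim for roughly this much data per bucket
    EXPECTED_OUTPUT_GB = 75      # your own estimate of clean_dataset at rest (hypothetical figure)

    BUCKET_COUNT = max(1, round(EXPECTED_OUTPUT_GB * 1024 / TARGET_BUCKET_MB))
    print(BUCKET_COUNT)          # -> 600 for a ~75GB output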

    Your output will now look like the following:

    - output: raw_dataset
      input: external_jdbc_system
      hive_partitioning: none
      bucketing: none
      transactions:
        - SNAPSHOT
        - APPEND
        - APPEND
    - output: clean_dataset
      input: raw_dataset
      hive_partitioning: none
      bucketing: BUCKET_COUNT by PRIMARY_KEY
      transactions:
        - SNAPSHOT
        - SNAPSHOT
        - SNAPSHOT