Tags: pyspark, palantir-foundry, foundry-code-repositories, foundry-python-transform

Why do I see repeated materializations of a DataFrame in my build?


I'm executing the following code:

from pyspark.sql import types as T, functions as F, SparkSession
spark = SparkSession.builder.getOrCreate()

schema = T.StructType([
  T.StructField("col_1", T.IntegerType(), False),
  T.StructField("col_2", T.IntegerType(), False),
  T.StructField("measure_1", T.FloatType(), False),
  T.StructField("measure_2", T.FloatType(), False),
])
data = [
  {"col_1": 1, "col_2": 2, "measure_1": 0.5, "measure_2": 1.5},
  {"col_1": 2, "col_2": 3, "measure_1": 2.5, "measure_2": 3.5}
]

df = spark.createDataFrame(data, schema)
df.show()

"""
+-----+-----+---------+---------+
|col_1|col_2|measure_1|measure_2|
+-----+-----+---------+---------+
|    1|    2|      0.5|      1.5|
|    2|    3|      2.5|      3.5|
+-----+-----+---------+---------+
"""

group_cols = ["col_1", "col_2"]
measure_cols = ["measure_1", "measure_2"]
for col in measure_cols:
  stats = df.groupBy(group_cols).agg(
    F.max(col).alias("max_" + col),
    F.avg(col).alias("avg_" + col),
  )
  df = df.join(stats, group_cols)
df.show()

"""
+-----+-----+---------+---------+-------------+-------------+-------------+-------------+
|col_1|col_2|measure_1|measure_2|max_measure_1|avg_measure_1|max_measure_2|avg_measure_2|
+-----+-----+---------+---------+-------------+-------------+-------------+-------------+
|    2|    3|      2.5|      3.5|          2.5|          2.5|          3.5|          3.5|
|    1|    2|      0.5|      1.5|          0.5|          0.5|          1.5|          1.5|
+-----+-----+---------+---------+-------------+-------------+-------------+-------------+
"""

The problem arises when my initial df isn't this simple but is the result of a series of joins or other operations. When I look at my job, df appears to be recomputed several times as my groupBy operations execute. For the simple case above, the query plan is:


df.explain()
"""
>>> df.explain()
== Physical Plan ==
*(11) Project [col_1#26, col_2#27, measure_1#28, measure_2#29, max_measure_1#56, avg_measure_1#58, max_measure_2#80, avg_measure_2#82]
+- *(11) SortMergeJoin [col_1#26, col_2#27], [col_1#87, col_2#88], Inner
   :- *(5) Project [col_1#26, col_2#27, measure_1#28, measure_2#29, max_measure_1#56, avg_measure_1#58]
   :  +- *(5) SortMergeJoin [col_1#26, col_2#27], [col_1#63, col_2#64], Inner
   :     :- *(2) Sort [col_1#26 ASC NULLS FIRST, col_2#27 ASC NULLS FIRST], false, 0
   :     :  +- Exchange hashpartitioning(col_1#26, col_2#27, 200), ENSURE_REQUIREMENTS, [id=#276]
   :     :     +- *(1) Scan ExistingRDD[col_1#26,col_2#27,measure_1#28,measure_2#29]
   :     +- *(4) Sort [col_1#63 ASC NULLS FIRST, col_2#64 ASC NULLS FIRST], false, 0
   :        +- *(4) HashAggregate(keys=[col_1#63, col_2#64], functions=[max(measure_1#65), avg(cast(measure_1#65 as double))])
   :           +- Exchange hashpartitioning(col_1#63, col_2#64, 200), ENSURE_REQUIREMENTS, [id=#282]
   :              +- *(3) HashAggregate(keys=[col_1#63, col_2#64], functions=[partial_max(measure_1#65), partial_avg(cast(measure_1#65 as double))])
   :                 +- *(3) Project [col_1#63, col_2#64, measure_1#65]
   :                    +- *(3) Scan ExistingRDD[col_1#63,col_2#64,measure_1#65,measure_2#66]
   +- *(10) Sort [col_1#87 ASC NULLS FIRST, col_2#88 ASC NULLS FIRST], false, 0
      +- *(10) HashAggregate(keys=[col_1#87, col_2#88], functions=[max(measure_2#90), avg(cast(measure_2#90 as double))])
         +- *(10) HashAggregate(keys=[col_1#87, col_2#88], functions=[partial_max(measure_2#90), partial_avg(cast(measure_2#90 as double))])
            +- *(10) Project [col_1#87, col_2#88, measure_2#90]
               +- *(10) SortMergeJoin [col_1#87, col_2#88], [col_1#63, col_2#64], Inner
                  :- *(7) Sort [col_1#87 ASC NULLS FIRST, col_2#88 ASC NULLS FIRST], false, 0
                  :  +- Exchange hashpartitioning(col_1#87, col_2#88, 200), ENSURE_REQUIREMENTS, [id=#293]
                  :     +- *(6) Project [col_1#87, col_2#88, measure_2#90]
                  :        +- *(6) Scan ExistingRDD[col_1#87,col_2#88,measure_1#89,measure_2#90]
                  +- *(9) Sort [col_1#63 ASC NULLS FIRST, col_2#64 ASC NULLS FIRST], false, 0
                     +- *(9) HashAggregate(keys=[col_1#63, col_2#64], functions=[])
                        +- Exchange hashpartitioning(col_1#63, col_2#64, 200), ENSURE_REQUIREMENTS, [id=#299]
                           +- *(8) HashAggregate(keys=[col_1#63, col_2#64], functions=[])
                              +- *(8) Project [col_1#63, col_2#64]
                                 +- *(8) Scan ExistingRDD[col_1#63,col_2#64,measure_1#65,measure_2#66]
"""

But if, for instance, I change the code above so that the initial df is the result of a union and a join:

from pyspark.sql import types as T, functions as F, SparkSession
spark = SparkSession.builder.getOrCreate()

schema = T.StructType([
  T.StructField("col_1", T.IntegerType(), False),
  T.StructField("col_2", T.IntegerType(), False),
  T.StructField("measure_1", T.FloatType(), False),
  T.StructField("measure_2", T.FloatType(), False),
])
data = [
  {"col_1": 1, "col_2": 2, "measure_1": 0.5, "measure_2": 1.5},
  {"col_1": 2, "col_2": 3, "measure_1": 2.5, "measure_2": 3.5}
]

df = spark.createDataFrame(data, schema)

right_schema = T.StructType([
  T.StructField("col_1", T.IntegerType(), False)
])
right_data = [
  {"col_1": 1},
  {"col_1": 1},
  {"col_1": 2},
  {"col_1": 2}
]
right_df = spark.createDataFrame(right_data, right_schema)

df = df.unionByName(df)
df = df.join(right_df, on="col_1")
df.show()

"""
+-----+-----+---------+---------+
|col_1|col_2|measure_1|measure_2|
+-----+-----+---------+---------+
|    1|    2|      0.5|      1.5|
|    1|    2|      0.5|      1.5|
|    1|    2|      0.5|      1.5|
|    1|    2|      0.5|      1.5|
|    2|    3|      2.5|      3.5|
|    2|    3|      2.5|      3.5|
|    2|    3|      2.5|      3.5|
|    2|    3|      2.5|      3.5|
+-----+-----+---------+---------+
"""

df.explain()

"""
== Physical Plan ==
*(7) Project [col_1#299, col_2#300, measure_1#301, measure_2#302, col_2#354, measure_1#355, measure_2#356]
+- *(7) SortMergeJoin [col_1#299], [col_1#353], Inner
   :- *(3) Sort [col_1#299 ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(col_1#299, 200), ENSURE_REQUIREMENTS, [id=#595]
   :     +- Union
   :        :- *(1) Scan ExistingRDD[col_1#299,col_2#300,measure_1#301,measure_2#302]
   :        +- *(2) Scan ExistingRDD[col_1#299,col_2#300,measure_1#301,measure_2#302]
   +- *(6) Sort [col_1#353 ASC NULLS FIRST], false, 0
      +- ReusedExchange [col_1#353, col_2#354, measure_1#355, measure_2#356], Exchange hashpartitioning(col_1#299, 200), ENSURE_REQUIREMENTS, [id=#595]
"""

group_cols = ["col_1", "col_2"]
measure_cols = ["measure_1", "measure_2"]
for col in measure_cols:
  stats = df.groupBy(group_cols).agg(
    F.max(col).alias("max_" + col),
    F.avg(col).alias("avg_" + col),
  )
  df = df.join(stats, group_cols)
df.show()

"""
+-----+-----+---------+---------+-------------+-------------+-------------+-------------+
|col_1|col_2|measure_1|measure_2|max_measure_1|avg_measure_1|max_measure_2|avg_measure_2|
+-----+-----+---------+---------+-------------+-------------+-------------+-------------+
|    2|    3|      2.5|      3.5|          2.5|          2.5|          3.5|          3.5|
|    2|    3|      2.5|      3.5|          2.5|          2.5|          3.5|          3.5|
|    2|    3|      2.5|      3.5|          2.5|          2.5|          3.5|          3.5|
|    2|    3|      2.5|      3.5|          2.5|          2.5|          3.5|          3.5|
|    1|    2|      0.5|      1.5|          0.5|          0.5|          1.5|          1.5|
|    1|    2|      0.5|      1.5|          0.5|          0.5|          1.5|          1.5|
|    1|    2|      0.5|      1.5|          0.5|          0.5|          1.5|          1.5|
|    1|    2|      0.5|      1.5|          0.5|          0.5|          1.5|          1.5|
+-----+-----+---------+---------+-------------+-------------+-------------+-------------+
"""

df.explain()

"""
== Physical Plan ==
*(31) Project [col_1#404, col_2#405, measure_1#406, measure_2#407, max_measure_1#465, avg_measure_1#467, max_measure_2#489, avg_measure_2#491]
+- *(31) SortMergeJoin [col_1#404, col_2#405], [col_1#496, col_2#497], Inner
   :- *(15) Project [col_1#404, col_2#405, measure_1#406, measure_2#407, max_measure_1#465, avg_measure_1#467]
   :  +- *(15) SortMergeJoin [col_1#404, col_2#405], [col_1#472, col_2#473], Inner
   :     :- *(7) Sort [col_1#404 ASC NULLS FIRST, col_2#405 ASC NULLS FIRST], false, 0
   :     :  +- Exchange hashpartitioning(col_1#404, col_2#405, 200), ENSURE_REQUIREMENTS, [id=#1508]
   :     :     +- *(6) Project [col_1#404, col_2#405, measure_1#406, measure_2#407]
   :     :        +- *(6) SortMergeJoin [col_1#404], [col_1#412], Inner
   :     :           :- *(3) Sort [col_1#404 ASC NULLS FIRST], false, 0
   :     :           :  +- Exchange hashpartitioning(col_1#404, 200), ENSURE_REQUIREMENTS, [id=#1494]
   :     :           :     +- Union
   :     :           :        :- *(1) Scan ExistingRDD[col_1#404,col_2#405,measure_1#406,measure_2#407]
   :     :           :        +- *(2) Scan ExistingRDD[col_1#404,col_2#405,measure_1#406,measure_2#407]
   :     :           +- *(5) Sort [col_1#412 ASC NULLS FIRST], false, 0
   :     :              +- Exchange hashpartitioning(col_1#412, 200), ENSURE_REQUIREMENTS, [id=#1500]
   :     :                 +- *(4) Scan ExistingRDD[col_1#412]
   :     +- *(14) Sort [col_1#472 ASC NULLS FIRST, col_2#473 ASC NULLS FIRST], false, 0
   :        +- Exchange hashpartitioning(col_1#472, col_2#473, 200), ENSURE_REQUIREMENTS, [id=#1639]
   :           +- *(13) HashAggregate(keys=[col_1#472, col_2#473], functions=[max(measure_1#474), avg(cast(measure_1#474 as double))])
   :              +- *(13) HashAggregate(keys=[col_1#472, col_2#473], functions=[partial_max(measure_1#474), partial_avg(cast(measure_1#474 as double))])
   :                 +- *(13) Project [col_1#472, col_2#473, measure_1#474]
   :                    +- *(13) SortMergeJoin [col_1#472], [col_1#412], Inner
   :                       :- *(10) Sort [col_1#472 ASC NULLS FIRST], false, 0
   :                       :  +- Exchange hashpartitioning(col_1#472, 200), ENSURE_REQUIREMENTS, [id=#1516]
   :                       :     +- Union
   :                       :        :- *(8) Project [col_1#472, col_2#473, measure_1#474]
   :                       :        :  +- *(8) Scan ExistingRDD[col_1#472,col_2#473,measure_1#474,measure_2#475]
   :                       :        +- *(9) Project [col_1#472, col_2#473, measure_1#474]
   :                       :           +- *(9) Scan ExistingRDD[col_1#472,col_2#473,measure_1#474,measure_2#475]
   :                       +- *(12) Sort [col_1#412 ASC NULLS FIRST], false, 0
   :                          +- ReusedExchange [col_1#412], Exchange hashpartitioning(col_1#412, 200), ENSURE_REQUIREMENTS, [id=#1500]
   +- *(30) Sort [col_1#496 ASC NULLS FIRST, col_2#497 ASC NULLS FIRST], false, 0
      +- *(30) HashAggregate(keys=[col_1#496, col_2#497], functions=[max(measure_2#499), avg(cast(measure_2#499 as double))])
         +- *(30) HashAggregate(keys=[col_1#496, col_2#497], functions=[partial_max(measure_2#499), partial_avg(cast(measure_2#499 as double))])
            +- *(30) Project [col_1#496, col_2#497, measure_2#499]
               +- *(30) SortMergeJoin [col_1#496, col_2#497], [col_1#472, col_2#473], Inner
                  :- *(22) Sort [col_1#496 ASC NULLS FIRST, col_2#497 ASC NULLS FIRST], false, 0
                  :  +- Exchange hashpartitioning(col_1#496, col_2#497, 200), ENSURE_REQUIREMENTS, [id=#1660]
                  :     +- *(21) Project [col_1#496, col_2#497, measure_2#499]
                  :        +- *(21) SortMergeJoin [col_1#496], [col_1#412], Inner
                  :           :- *(18) Sort [col_1#496 ASC NULLS FIRST], false, 0
                  :           :  +- Exchange hashpartitioning(col_1#496, 200), ENSURE_REQUIREMENTS, [id=#1544]
                  :           :     +- Union
                  :           :        :- *(16) Project [col_1#496, col_2#497, measure_2#499]
                  :           :        :  +- *(16) Scan ExistingRDD[col_1#496,col_2#497,measure_1#498,measure_2#499]
                  :           :        +- *(17) Project [col_1#496, col_2#497, measure_2#499]
                  :           :           +- *(17) Scan ExistingRDD[col_1#496,col_2#497,measure_1#498,measure_2#499]
                  :           +- *(20) Sort [col_1#412 ASC NULLS FIRST], false, 0
                  :              +- ReusedExchange [col_1#412], Exchange hashpartitioning(col_1#412, 200), ENSURE_REQUIREMENTS, [id=#1500]
                  +- *(29) Sort [col_1#472 ASC NULLS FIRST, col_2#473 ASC NULLS FIRST], false, 0
                     +- Exchange hashpartitioning(col_1#472, col_2#473, 200), ENSURE_REQUIREMENTS, [id=#1707]
                        +- *(28) HashAggregate(keys=[col_1#472, col_2#473], functions=[])
                           +- *(28) HashAggregate(keys=[col_1#472, col_2#473], functions=[])
                              +- *(28) Project [col_1#472, col_2#473]
                                 +- *(28) SortMergeJoin [col_1#472], [col_1#412], Inner
                                    :- *(25) Sort [col_1#472 ASC NULLS FIRST], false, 0
                                    :  +- Exchange hashpartitioning(col_1#472, 200), ENSURE_REQUIREMENTS, [id=#1566]
                                    :     +- Union
                                    :        :- *(23) Project [col_1#472, col_2#473]
                                    :        :  +- *(23) Scan ExistingRDD[col_1#472,col_2#473,measure_1#474,measure_2#475]
                                    :        +- *(24) Project [col_1#472, col_2#473]
                                    :           +- *(24) Scan ExistingRDD[col_1#472,col_2#473,measure_1#474,measure_2#475]
                                    +- *(27) Sort [col_1#412 ASC NULLS FIRST], false, 0
                                       +- ReusedExchange [col_1#412], Exchange hashpartitioning(col_1#412, 200), ENSURE_REQUIREMENTS, [id=#1500]
"""

The query plan shows the union + join being derived several times. My job's execution report reflects this, with a stage that has the identical number of tasks running again and again.
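The repetition is a consequence of lazy evaluation. Each DataFrame is a recipe rather than stored data, and every branch of the plan that references df can re-run df's whole recipe. Below is a toy, pure-Python model of that behaviour. It is not Spark code, and `node`, `aggregate`, `join_stats` and `build` are hypothetical names. Each node is a zero-argument function that rebuilds its inputs on every call. Wrapping the base node in `functools.lru_cache` stands in for `.cache()`.

```python
from collections import Counter
from functools import lru_cache

# Toy data mirroring the question: two rows, unioned with themselves,
# then joined to four right-hand keys.
ROWS = [(1, 2, 0.5, 1.5), (2, 3, 2.5, 3.5)]
RIGHT = [1, 1, 2, 2]

evaluations = Counter()  # how many times each node's recipe actually ran


def node(name, fn, *parents, cache=False):
    """Toy lazy node: a thunk that recomputes all of its parents on every call."""
    def run():
        evaluations[name] += 1
        return fn(*(parent() for parent in parents))
    # Memoizing the thunk stands in for .cache(): compute once, then reuse.
    return lru_cache(maxsize=None)(run) if cache else run


def aggregate(rows, i):
    """max and avg of column i per (col_1, col_2) group."""
    groups = {}
    for r in rows:
        groups.setdefault((r[0], r[1]), []).append(r[i])
    return {k: (max(v), sum(v) / len(v)) for k, v in groups.items()}


def join_stats(rows, stats):
    """Append each row's group statistics (an inner join on col_1, col_2)."""
    return [r + stats[(r[0], r[1])] for r in rows]


def build(cache_base):
    base = node(
        "union+join",
        lambda: [r for r in ROWS + ROWS for k in RIGHT if r[0] == k],
        cache=cache_base,
    )
    df = base
    for i in (2, 3):  # positions of measure_1 and measure_2
        stats = node(f"stats_{i}", lambda rows, i=i: aggregate(rows, i), df)
        df = node(f"join_{i}", join_stats, df, stats)
    return df


runs = {}
for cache_base in (False, True):
    evaluations.clear()
    result = build(cache_base)()  # the "action": evaluate the final node
    runs[cache_base] = (evaluations["union+join"], result)
    print(f"cache={cache_base}: union+join ran {evaluations['union+join']} time(s)")
```

Without memoization the base recipe runs four times, which matches the four Union nodes in the plan above. With memoization it runs once. Spark's real behaviour also depends on the optimizer; for example, ReusedExchange can share identical shuffles. Treat this model as intuition, not as a cost model.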

How can I stop this re-derivation from happening?


Solution

  • The loop in your transform repeatedly aggregates and joins against the same base DataFrame, so it would benefit from PySpark's .cache() method. Calling .cache() marks the DataFrame to be kept (in memory, spilling to disk if needed) the first time an action computes it. Later actions then read the stored data instead of re-running its plan. As a result, the initial union + join is computed a single time and re-used by the subsequent transformations.

    This is a one-line addition that can substantially reduce the repeated work in your job.

    from pyspark.sql import types as T, functions as F, SparkSession
    spark = SparkSession.builder.getOrCreate()
    
    schema = T.StructType([
      T.StructField("col_1", T.IntegerType(), False),
      T.StructField("col_2", T.IntegerType(), False),
      T.StructField("measure_1", T.FloatType(), False),
      T.StructField("measure_2", T.FloatType(), False),
    ])
    data = [
      {"col_1": 1, "col_2": 2, "measure_1": 0.5, "measure_2": 1.5},
      {"col_1": 2, "col_2": 3, "measure_1": 2.5, "measure_2": 3.5}
    ]
    
    df = spark.createDataFrame(data, schema)
    
    right_schema = T.StructType([
      T.StructField("col_1", T.IntegerType(), False)
    ])
    right_data = [
      {"col_1": 1},
      {"col_1": 1},
      {"col_1": 2},
      {"col_1": 2}
    ]
    right_df = spark.createDataFrame(right_data, right_schema)
    
    df = df.unionByName(df)
    df = df.join(right_df, on="col_1")
    
    # ========= Added this line BEFORE the loop
    df = df.cache()
    # =========
    
    group_cols = ["col_1", "col_2"]
    measure_cols = ["measure_1", "measure_2"]
    for col in measure_cols:
      stats = df.groupBy(group_cols).agg(
        F.max(col).alias("max_" + col),
        F.avg(col).alias("avg_" + col),
      )
      df = df.join(stats, group_cols)
    df.show()
    
    """
    +-----+-----+---------+---------+-------------+-------------+-------------+-------------+
    |col_1|col_2|measure_1|measure_2|max_measure_1|avg_measure_1|max_measure_2|avg_measure_2|
    +-----+-----+---------+---------+-------------+-------------+-------------+-------------+
    |    2|    3|      2.5|      3.5|          2.5|          2.5|          3.5|          3.5|
    |    2|    3|      2.5|      3.5|          2.5|          2.5|          3.5|          3.5|
    |    2|    3|      2.5|      3.5|          2.5|          2.5|          3.5|          3.5|
    |    2|    3|      2.5|      3.5|          2.5|          2.5|          3.5|          3.5|
    |    1|    2|      0.5|      1.5|          0.5|          0.5|          1.5|          1.5|
    |    1|    2|      0.5|      1.5|          0.5|          0.5|          1.5|          1.5|
    |    1|    2|      0.5|      1.5|          0.5|          0.5|          1.5|          1.5|
    |    1|    2|      0.5|      1.5|          0.5|          0.5|          1.5|          1.5|
    +-----+-----+---------+---------+-------------+-------------+-------------+-------------+
    """
    
    df.explain()
    """
    >>> df.explain()
    == Physical Plan ==
    *(4) Project [col_1#1265, col_2#1266, measure_1#1267, measure_2#1268, max_measure_1#1312, avg_measure_1#1314, max_measure_2#1336, avg_measure_2#1338]
    +- *(4) BroadcastHashJoin [col_1#1265, col_2#1266], [col_1#1343, col_2#1344], Inner, BuildRight, false
       :- *(4) Project [col_1#1265, col_2#1266, measure_1#1267, measure_2#1268, max_measure_1#1312, avg_measure_1#1314]
       :  +- *(4) BroadcastHashJoin [col_1#1265, col_2#1266], [col_1#1319, col_2#1320], Inner, BuildLeft, false
       :     :- BroadcastExchange HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, false] as bigint), 32) | (cast(input[1, int, false] as bigint) & 4294967295))),false), [id=#2439]
       :     :  +- *(1) ColumnarToRow
       :     :     +- InMemoryTableScan [col_1#1265, col_2#1266, measure_1#1267, measure_2#1268]
       :     :           +- InMemoryRelation [col_1#1265, col_2#1266, measure_1#1267, measure_2#1268], StorageLevel(disk, memory, deserialized, 1 replicas)
       :     :                 +- *(6) Project [col_1#1265, col_2#1266, measure_1#1267, measure_2#1268]
       :     :                    +- *(6) SortMergeJoin [col_1#1265], [col_1#1273], Inner
       :     :                       :- *(3) Sort [col_1#1265 ASC NULLS FIRST], false, 0
       :     :                       :  +- Exchange hashpartitioning(col_1#1265, 200), ENSURE_REQUIREMENTS, [id=#2169]
       :     :                       :     +- Union
       :     :                       :        :- *(1) Scan ExistingRDD[col_1#1265,col_2#1266,measure_1#1267,measure_2#1268]
       :     :                       :        +- *(2) Scan ExistingRDD[col_1#1265,col_2#1266,measure_1#1267,measure_2#1268]
       :     :                       +- *(5) Sort [col_1#1273 ASC NULLS FIRST], false, 0
       :     :                          +- Exchange hashpartitioning(col_1#1273, 200), ENSURE_REQUIREMENTS, [id=#2175]
       :     :                             +- *(4) Scan ExistingRDD[col_1#1273]
       :     +- *(4) HashAggregate(keys=[col_1#1319, col_2#1320], functions=[max(measure_1#1321), avg(cast(measure_1#1321 as double))])
       :        +- *(4) HashAggregate(keys=[col_1#1319, col_2#1320], functions=[partial_max(measure_1#1321), partial_avg(cast(measure_1#1321 as double))])
       :           +- *(4) ColumnarToRow
       :              +- InMemoryTableScan [col_1#1319, col_2#1320, measure_1#1321]
       :                    +- InMemoryRelation [col_1#1319, col_2#1320, measure_1#1321, measure_2#1322], StorageLevel(disk, memory, deserialized, 1 replicas)
       :                          +- *(6) Project [col_1#1265, col_2#1266, measure_1#1267, measure_2#1268]
       :                             +- *(6) SortMergeJoin [col_1#1265], [col_1#1273], Inner
       :                                :- *(3) Sort [col_1#1265 ASC NULLS FIRST], false, 0
       :                                :  +- Exchange hashpartitioning(col_1#1265, 200), ENSURE_REQUIREMENTS, [id=#2169]
       :                                :     +- Union
       :                                :        :- *(1) Scan ExistingRDD[col_1#1265,col_2#1266,measure_1#1267,measure_2#1268]
       :                                :        +- *(2) Scan ExistingRDD[col_1#1265,col_2#1266,measure_1#1267,measure_2#1268]
       :                                +- *(5) Sort [col_1#1273 ASC NULLS FIRST], false, 0
       :                                   +- Exchange hashpartitioning(col_1#1273, 200), ENSURE_REQUIREMENTS, [id=#2175]
       :                                      +- *(4) Scan ExistingRDD[col_1#1273]
       +- BroadcastExchange HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, false] as bigint), 32) | (cast(input[1, int, false] as bigint) & 4294967295))),false), [id=#2461]
          +- *(3) HashAggregate(keys=[col_1#1343, col_2#1344], functions=[max(measure_2#1346), avg(cast(measure_2#1346 as double))])
             +- *(3) HashAggregate(keys=[col_1#1343, col_2#1344], functions=[partial_max(measure_2#1346), partial_avg(cast(measure_2#1346 as double))])
                +- *(3) Project [col_1#1343, col_2#1344, measure_2#1346]
                   +- *(3) BroadcastHashJoin [col_1#1343, col_2#1344], [col_1#1319, col_2#1320], Inner, BuildRight, false
                      :- *(3) ColumnarToRow
                      :  +- InMemoryTableScan [col_1#1343, col_2#1344, measure_2#1346]
                      :        +- InMemoryRelation [col_1#1343, col_2#1344, measure_1#1345, measure_2#1346], StorageLevel(disk, memory, deserialized, 1 replicas)
                      :              +- *(6) Project [col_1#1265, col_2#1266, measure_1#1267, measure_2#1268]
                      :                 +- *(6) SortMergeJoin [col_1#1265], [col_1#1273], Inner
                      :                    :- *(3) Sort [col_1#1265 ASC NULLS FIRST], false, 0
                      :                    :  +- Exchange hashpartitioning(col_1#1265, 200), ENSURE_REQUIREMENTS, [id=#2169]
                      :                    :     +- Union
                      :                    :        :- *(1) Scan ExistingRDD[col_1#1265,col_2#1266,measure_1#1267,measure_2#1268]
                      :                    :        +- *(2) Scan ExistingRDD[col_1#1265,col_2#1266,measure_1#1267,measure_2#1268]
                      :                    +- *(5) Sort [col_1#1273 ASC NULLS FIRST], false, 0
                      :                       +- Exchange hashpartitioning(col_1#1273, 200), ENSURE_REQUIREMENTS, [id=#2175]
                      :                          +- *(4) Scan ExistingRDD[col_1#1273]
                      +- BroadcastExchange HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, false] as bigint), 32) | (cast(input[1, int, false] as bigint) & 4294967295))),false), [id=#2454]
                         +- *(2) HashAggregate(keys=[col_1#1319, col_2#1320], functions=[])
                            +- *(2) HashAggregate(keys=[col_1#1319, col_2#1320], functions=[])
                               +- *(2) ColumnarToRow
                                  +- InMemoryTableScan [col_1#1319, col_2#1320]
                                        +- InMemoryRelation [col_1#1319, col_2#1320, measure_1#1321, measure_2#1322], StorageLevel(disk, memory, deserialized, 1 replicas)
                                              +- *(6) Project [col_1#1265, col_2#1266, measure_1#1267, measure_2#1268]
                                                 +- *(6) SortMergeJoin [col_1#1265], [col_1#1273], Inner
                                                    :- *(3) Sort [col_1#1265 ASC NULLS FIRST], false, 0
                                                    :  +- Exchange hashpartitioning(col_1#1265, 200), ENSURE_REQUIREMENTS, [id=#2169]
                                                    :     +- Union
                                                    :        :- *(1) Scan ExistingRDD[col_1#1265,col_2#1266,measure_1#1267,measure_2#1268]
                                                    :        +- *(2) Scan ExistingRDD[col_1#1265,col_2#1266,measure_1#1267,measure_2#1268]
                                                    +- *(5) Sort [col_1#1273 ASC NULLS FIRST], false, 0
                                                       +- Exchange hashpartitioning(col_1#1273, 200), ENSURE_REQUIREMENTS, [id=#2175]
                                                          +- *(4) Scan ExistingRDD[col_1#1273]
    """
    

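    You can now see in the query plan that each branch reads the base DataFrame through an InMemoryTableScan of an InMemoryRelation, instead of re-running the union + join and its shuffles. The cached plan is still printed beneath every InMemoryRelation. However, all copies point to the same cached data (note the repeated Exchange with [id=#2169]), so it is computed only once, and your job's execution will reflect as much.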

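    Note: .cache() does not truncate the plan's lineage. The full union + join remains beneath each InMemoryRelation, and Spark can fall back on it to recompute cached partitions that are lost. Caching only changes how the data is produced and re-used at execution time.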