
PySpark + Cassandra: Getting distinct values of partition key


I'm trying to get the distinct values of the partition key of a Cassandra table in PySpark. However, PySpark doesn't seem to understand me and fully iterates over all the data (which is a lot) instead of querying the index.

This is the code I use, which looks pretty straightforward to me:

from pyspark.sql import SparkSession

spark = SparkSession \
    .builder \
    .appName("Spark! This town not big enough for the two of us.") \
    .getOrCreate()

ct = spark.read\
    .format("org.apache.spark.sql.cassandra")\
    .options(table="avt_sensor_data", keyspace="ipe_smart_meter")\
    .load()

all_sensors = ct.select("machine_name", "sensor_name")\
    .distinct() \
    .collect()

The columns "machine_name" and "sensor_name" together form the partition key (see below for the complete schema). In my opinion, this should be super-fast, and in fact, if I execute this query in cql it takes only a couple of seconds:

select distinct machine_name,sensor_name from ipe_smart_meter.avt_sensor_data;

However, the Spark job takes about 10 hours to complete. From what Spark tells me about its plan, it looks like it really wants to iterate over all the data:

== Physical Plan ==
*HashAggregate(keys=[machine_name#0, sensor_name#1], functions=[], output=[machine_name#0, sensor_name#1])
+- Exchange hashpartitioning(machine_name#0, sensor_name#1, 200)
   +- *HashAggregate(keys=[machine_name#0, sensor_name#1], functions=[], output=[machine_name#0, sensor_name#1])
      +- *Scan org.apache.spark.sql.cassandra.CassandraSourceRelation@2ee2f21d [machine_name#0,sensor_name#1] ReadSchema: struct<machine_name:string,sensor_name:string>

I'm not an expert, but that doesn't look like "use the cassandra index" to me.

What am I doing wrong? Is there any way of telling spark to delegate the task of getting the distinct values from cassandra? Any help would be greatly appreciated!

If that helps, here is a schema description of the underlying cassandra table:

CREATE KEYSPACE ipe_smart_meter WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '2'}  AND durable_writes = true;

CREATE TABLE ipe_smart_meter.avt_sensor_data (
    machine_name text,
    sensor_name text,
    ts timestamp,
    id bigint,
    value double,
    PRIMARY KEY ((machine_name, sensor_name), ts)
) WITH CLUSTERING ORDER BY (ts DESC)
    AND bloom_filter_fp_chance = 0.01
    AND caching = {'keys': 'ALL', 'rows_per_partition': 'NONE'}
    AND comment = '[PRODUCTION] Table for raw data from AVT smart meters.'
    AND compaction = {'class': 'org.apache.cassandra.db.compaction.DateTieredCompactionStrategy', 'max_threshold': '32', 'min_threshold': '4'}
    AND compression = {'chunk_length_in_kb': '64', 'class': 'org.apache.cassandra.io.compress.LZ4Compressor'}
    AND crc_check_chance = 1.0
    AND dclocal_read_repair_chance = 0.1
    AND default_time_to_live = 0
    AND gc_grace_seconds = 864000
    AND max_index_interval = 2048
    AND memtable_flush_period_in_ms = 0
    AND min_index_interval = 128
    AND read_repair_chance = 0.0
    AND speculative_retry = '99PERCENTILE';

Solution

  • It seems the automatic Cassandra server-side predicate pushdown works only when selecting, filtering, or ordering.

    https://github.com/datastax/spark-cassandra-connector/blob/master/doc/14_data_frames.md

    So, in the case of your distinct(), Spark fetches all the rows first and only then applies distinct().

    Solution 1

    You say your cql select distinct... is already super-fast. I guess there is a relatively small number of partition keys (combinations of machine_name and sensor_name), each with a huge number of 'ts' values.

    So, the simplest solution is to just use cql directly (for example, with cassandra-driver).
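
    For example, a minimal sketch with the Python cassandra-driver package (the 127.0.0.1 contact point is an assumption; point it at your cluster):

    from cassandra.cluster import Cluster

    # connect to the cluster and keyspace (contact point is an assumption)
    cluster = Cluster(['127.0.0.1'])
    session = cluster.connect('ipe_smart_meter')

    # DISTINCT over the full partition key is answered from the partition
    # index, so this behaves like the couple-of-seconds cqlsh query
    rows = session.execute(
        'SELECT DISTINCT machine_name, sensor_name FROM avt_sensor_data'
    )
    all_sensors = [(row.machine_name, row.sensor_name) for row in rows]

    cluster.shutdown()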

    Solution 2

    Since Cassandra is a query-first database, just create one more table that holds only the partition keys your distinct query needs.

    CREATE TABLE ipe_smart_meter.avt_sensor_name_machine_name (
        machine_name text,
        sensor_name text,
        PRIMARY KEY ((machine_name, sensor_name))
    );
    

    Then, every time you insert a row into your original table, also insert the machine_name and sensor_name into the new table. Since it holds only partition keys, it is a naturally distinct table for your query: just read all of its rows, no distinct() processing needed, which should be super-fast. A sketch of the dual insert follows.
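
    A minimal sketch of that dual insert with the Python cassandra-driver (the contact point and the sample values are assumptions); a logged batch keeps the two writes together:

    import datetime

    from cassandra.cluster import Cluster
    from cassandra.query import BatchStatement

    cluster = Cluster(['127.0.0.1'])
    session = cluster.connect('ipe_smart_meter')

    insert_data = session.prepare(
        'INSERT INTO avt_sensor_data (machine_name, sensor_name, ts, id, value) '
        'VALUES (?, ?, ?, ?, ?)'
    )
    insert_key = session.prepare(
        'INSERT INTO avt_sensor_name_machine_name (machine_name, sensor_name) '
        'VALUES (?, ?)'
    )

    # hypothetical values for a single sensor reading
    machine, sensor = 'machine-1', 'sensor-1'
    ts, row_id, value = datetime.datetime.now(), 1, 42.0

    batch = BatchStatement()
    batch.add(insert_data, (machine, sensor, ts, row_id, value))
    # inserts are upserts in cassandra, so repeated keys are harmless
    batch.add(insert_key, (machine, sensor))
    session.execute(batch)

    Reading the distinct pairs back is then just a full read of avt_sensor_name_machine_name, via cql or via the same spark.read pattern from the question, with no distinct() step at all.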

    Solution 3

    I think solution 2 is the best. But if you don't want to do two inserts for one record, one more solution is to change your table and create a materialized view on top of it.

    CREATE TABLE ipe_smart_meter.avt_sensor_data (
        machine_name text,
        sensor_name text,
        ts timestamp,
        id bigint,
        value double,
        dist_hint_num smallint,
        PRIMARY KEY ((machine_name, sensor_name), ts)
    ) WITH CLUSTERING ORDER BY (ts DESC)
    ;
    
    CREATE MATERIALIZED VIEW IF NOT EXISTS ipe_smart_meter.avt_sensor_data_mv AS
      SELECT
        machine_name
        ,sensor_name
        ,ts
        ,dist_hint_num
      FROM ipe_smart_meter.avt_sensor_data
      WHERE
        machine_name IS NOT NULL
        AND sensor_name IS NOT NULL
        AND ts IS NOT NULL
        AND dist_hint_num IS NOT NULL
      PRIMARY KEY ((dist_hint_num), machine_name, sensor_name, ts)
      WITH CLUSTERING ORDER BY (machine_name ASC, sensor_name DESC, ts DESC)
    ;
    

    The dist_hint_num column limits the total number of partitions your query has to iterate over, and spreads the records across those partitions.

    For example, integers from 0 to 15. A random integer (random.randint(0, 15)) or a hash-based integer (hash_func(machine_name + sensor_name) % 16) is fine. Then, when you query as in the PySpark snippet below, Cassandra reads records from only those 16 partitions, which may be more efficient than your current situation.
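
    The hash_func above is not pinned down in the answer; one possible deterministic implementation (an assumption, not part of the original) is:

    import hashlib

    def dist_hint_num(machine_name: str, sensor_name: str, buckets: int = 16) -> int:
        # md5 is stable across processes (unlike python's built-in hash()),
        # so the same key always lands in the same bucket
        digest = hashlib.md5((machine_name + sensor_name).encode('utf-8')).digest()
        return digest[0] % buckets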

    But, anyway, all the records still have to be read and then run through distinct() (a shuffle happens), so it is not space-efficient. I don't think this is a good solution.

    import functools

    from pyspark.sql.functions import col, expr
    from pyspark.sql.types import StringType, StructField, StructType

    # spark_session is the SparkSession created earlier (called `spark` in
    # the question); union the 16 single-bucket reads into one DataFrame
    df_all_machine_sensor = functools.reduce(
        lambda df, dist_hint_num: df.union(
            other=spark_session.read.format(
                'org.apache.spark.sql.cassandra',
            ).options(
                keyspace='ipe_smart_meter',
                table='avt_sensor_data_mv',
            ).load().filter(
                # read exactly one dist_hint_num bucket; the cast matches
                # the smallint type of the column
                col('dist_hint_num') == expr(
                    f'CAST({dist_hint_num} AS SMALLINT)'
                )
            ).select(
                col('machine_name'),
                col('sensor_name'),
            ),
        ),
        range(0, 16),
        # empty DataFrame with the right schema as the initial reduce value
        spark_session.createDataFrame(
            data=(),
            schema=StructType(
                fields=(
                    StructField(
                        name='machine_name',
                        dataType=StringType(),
                        nullable=False,
                    ),
                    StructField(
                        name='sensor_name',
                        dataType=StringType(),
                        nullable=False,
                    ),
                ),
            ),
        ),
    ).distinct().persist().alias(
        'df_all_machine_sensor',
    )