apache-spark, cassandra, datastax, datastax-enterprise

How can I achieve server-side filtering with a join in the Spark DataFrame API?


This is part of my Spark app. The first part gets all the articles created within the last hour, the second part grabs all of these articles' comments, and the third part adds the comments to the articles. The problem is that the articles.map(lambda x:(x.id,x.id)).join(axes) part is too slow; it takes around 1 minute. I would like to get this down to 10 seconds or even less, but I don't know how. Thanks for your reply.

articles = sqlContext.read.format("org.apache.spark.sql.cassandra").options(table="articles", keyspace=source).load() \
    .map(lambda x: x).filter(lambda x: x.created_at is not None) \
    .filter(lambda x: x.created_at >= datetime.now() - timedelta(hours=1) and x.created_at <= datetime.now() - timedelta(hours=0)) \
    .cache()

axes = sqlContext.read.format("org.apache.spark.sql.cassandra").options(table="axes", keyspace=source).load().map(lambda x:(x.article,x))

speed_rdd = articles.map(lambda x:(x.id,x.id)).join(axes)

EDIT

This is my new code, which I changed according to your suggestions. It is already twice as fast as before, so thanks for that ;). There is one more improvement I would like to make to the last part, the axes query, which is still too slow and needs 38 seconds for 30 million rows:

range_expr = col("created_at").between(
    datetime.now() - timedelta(hours=timespan),
    datetime.now() - timedelta(hours=time_delta(timespan))
)

article_ids = sqlContext.read.format("org.apache.spark.sql.cassandra") \
    .options(table="article_by_created_at", keyspace=source).load() \
    .where(range_expr).select('article', 'created_at').persist()

axes = sqlContext.read.format("org.apache.spark.sql.cassandra").options(table="axes", keyspace=source).load()

I tried the following (it should replace the last axes part of my code), and it is also the solution I would like to have, but it doesn't seem to work properly:

in_expr = col("article").isin(article_ids.collect())
axes = sqlContext.read.format("org.apache.spark.sql.cassandra") \
    .options(table="axes", keyspace=source).load().where(in_expr)

I always get this error message:

in_expr = col("article").isin(article_ids.collect())
Traceback (most recent call last):                                              
  File "<stdin>", line 1, in <module>
TypeError: 'Column' object is not callable

Thanks for your help.


Solution

  • As mentioned before, if you want to achieve reasonable performance, don't convert your data to an RDD. That not only makes optimizations like predicate pushdown impossible, but also introduces a huge overhead from moving data out of the JVM to Python.

    Instead you should use SQL expressions / the DataFrame API in a way similar to this:

    from pyspark.sql.functions import col, expr, current_timestamp
    
    range_expr = col("created_at").between(
        current_timestamp() - expr("INTERVAL 1 HOUR"),
        current_timestamp())
    
    articles = (sqlContext.read.format("org.apache.spark.sql.cassandra")
        .options(...).load()
        .where(col("created_at").isNotNull())  # This is not really required
        .where(range_expr))
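
    With the DataFrame reader, the created_at range filter can be offered to the Cassandra connector instead of being evaluated row by row in Python. A quick way to see how the query will actually be executed is to print the physical plan; how pushed predicates are displayed varies with your Spark and connector versions, so treat this only as a sanity check:

    # Print the physical plan for the filtered read. Depending on the
    # Spark / connector versions, predicates handled by the data source
    # are reflected in the scan step of the plan.
    articles.explain()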
    

    It should also be possible to formulate the predicate expression using standard Python utilities, as you've done before:

    import datetime
    
    range_expr = col("created_at").between(
        datetime.datetime.now() - datetime.timedelta(hours=1),
        datetime.datetime.now()
    )
    

    The subsequent join should be performed without moving data out of the DataFrame as well:

    axes = (sqlContext.read.format("org.apache.spark.sql.cassandra")
        .options(...)
        .load())
    
    articles.join(axes, ["id"])
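
    Putting the pieces together for your tables, here is a minimal, illustrative sketch. It assumes the axes rows reference their article through an article column (as in your original code) and Spark >= 1.5; renaming article to id before the join is just one way to align the key names:

    from pyspark.sql.functions import col, current_timestamp, expr

    range_expr = col("created_at").between(
        current_timestamp() - expr("INTERVAL 1 HOUR"),
        current_timestamp())

    # Both sides stay DataFrames, so the time-range filter can be handed to
    # the Cassandra source and the join itself is planned by Catalyst.
    articles = (sqlContext.read.format("org.apache.spark.sql.cassandra")
        .options(table="articles", keyspace=source).load()
        .where(range_expr))

    axes = (sqlContext.read.format("org.apache.spark.sql.cassandra")
        .options(table="axes", keyspace=source).load()
        .withColumnRenamed("article", "id"))  # align the join key names

    speed_df = articles.join(axes, ["id"])

    Regarding the 'Column' object is not callable error from your edit: Column.isin was only added around Spark 1.5 (earlier releases exposed inSet instead), which would likely produce exactly this error on an older version. If you do go the isin route, pass plain values rather than Row objects, e.g. article_ids.rdd.map(lambda r: r.article).collect().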