apache-spark hive pyspark hadoop2

Building derived column using Spark transformations


I have a table with the records shown below.

Id   Indicator     Date
1       R       2018-01-20
1       R       2018-10-21
1       P       2019-01-22
2       R       2018-02-28
2       P       2018-05-22
2       P       2019-03-05 

I need to pick the Ids that had two or more R indicators in the last year and derive a new column called Marked_Flag, set to Y for those Ids and N otherwise. The expected output should look like this:

Id  Marked_Flag 
1   Y
2   N

So far, I have loaded the records into one Dataset and then built a second Dataset from it. The code looks like this:

Dataset<Row> getIndicators = spark.sql("select id, count(indicator) as indi_count from source where indicator = 'R' group by id");

Dataset<Row> getFlag = spark.sql("select id, case when indi_count > 1 then 'Y' else 'N' end as Marked_Flag from getIndicators");

But my lead wants this done with a single Dataset, using Spark transformations. I am pretty new to Spark, so any guidance or a code snippet would be very helpful.

Solution

  • Try the following. Note that I am using a PySpark DataFrame here:

    from pyspark.sql import SparkSession
    import pyspark.sql.functions as F
    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([
    [1, "R", "2018-01-20"],
    [1, "R", "2018-10-21"],
    [1, "P", "2019-01-22"],
    [2, "R", "2018-02-28"],
    [2, "P", "2018-05-22"],
    [2, "P", "2019-03-05"]], ["Id", "Indicator","Date"])
    
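    # Keep only the "R" rows, then count them per Id
    gr = df.filter(F.col("Indicator") == "R").groupBy("Id").agg(F.count("Indicator").alias("r_count"))
    # Flag Ids with more than one "R" row, then drop the helper count
    gr = gr.withColumn("Marked_Flag", F.when(F.col("r_count") > 1, "Y").otherwise("N")).drop("r_count")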
    gr.show()
    
    # +---+-----------+
    # | Id|Marked_Flag|
    # +---+-----------+
    # |  1|          Y|
    # |  2|          N|
    # +---+-----------+
    #