Tags: arrays, dataframe, pyspark, user-defined-functions, contains

Pyspark: check if a tuple is contained in a list of tuples


I am trying to analyze the reliability of my data from two separate sources (A and B). Since the two sources have rather different sets of fields, I am focusing on the fields they have in common and running a comparison on those.

I selected price and quantity, and I want to check whether each pair [priceA, quantityA] from source A is contained in my list of pairs [[price1B, quantity1B], [price2B, quantity2B], ...] from source B.

I tried to create a UDF for this based on other references, but I have just started with PySpark and I don't really understand how to define my UDF or which DataType to specify in this case.

I have two DataFrames, one for each source.

To each DataFrame I added a new "combined" column (combined_a for df_a) with the schema StructField(combined_a,ArrayType(IntegerType,true),false):

df_a = df_a.withColumn("combined_a", array("Quantity", "PRICE"))

and created a list of the unique pairs:

list_a = list(df_a.select("combined_a").distinct().toPandas()["combined_a"])

Output of list_a:

list_a = [ [81.0, 100.0], [56.0, 6.0], [10000.0, 45.32], [42.0, 6.0] ...]
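The values shown above are Python lists rather than tuples. Lists are unhashable, so if the pairs are to be deduplicated or used for fast lookups in plain Python, converting each one to a tuple is a common step. A small sketch using only the four pairs shown (the rest of the list is elided). Note also that array("Quantity", "PRICE") puts quantity first, so the pairs from both sources must be built in the same order for a comparison to match:

```python
# The four pairs shown above, as they arrive in Python (inner lists)
list_a = [[81.0, 100.0], [56.0, 6.0], [10000.0, 45.32], [42.0, 6.0]]

# set(list_a) would raise TypeError because lists are unhashable;
# converting each pair to a tuple makes it usable as a set element
pairs_a = {tuple(pair) for pair in list_a}

print((56.0, 6.0) in pairs_a)  # True
print((6.0, 56.0) in pairs_a)  # False: the order of the two values matters
```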

I couldn't find a built-in function that does what I need: I want to append a new Boolean column "combinaison_in_b_found". I tried:

df_a = df_a.withColumn('combinaison_in_b_found' , col('combined_a').isin(list_b))

This returns the following error:

An error occurred while calling z:org.apache.spark.sql.functions.lit.
: java.lang.RuntimeException: Unsupported literal type class java.util.ArrayList [50, 51]

So I moved on to a UDF and tried:

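from pyspark.sql.functions import udf, col
from pyspark.sql.types import BooleanType

def IsInDataframe(combined_a, list_b):
    # Compare the whole pair (e.g. [56.0, 6.0]) with each pair in list_b;
    # checking the numbers one by one never matches, since list_b holds pairs
    for pair in list_b:
        if combined_a == pair:
            return True
    return False

def udf_append(list_b):
    # Returns a UDF with list_b captured in the lambda's closure;
    # BooleanType makes the result Boolean instead of the default string type
    return udf(lambda combined_a: IsInDataframe(combined_a, list_b), BooleanType())

# cast() is a Column method, not a DataFrame method; with BooleanType as the
# UDF's return type no cast is needed
df_a = df_a.withColumn("combinaison_in_b_found", udf_append(list_b)(col("combined_a")))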
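Testing the numbers of a pair one by one cannot work here, because list_b contains pairs, not numbers: 56.0 is compared against whole lists like [81.0, 100.0] and never matches. A plain-Python comparison of the two checks (the function names are made up for illustration):

```python
list_b = [[81.0, 100.0], [56.0, 6.0], [10000.0, 45.32], [42.0, 6.0]]

def numbers_in_b(combined_a, list_b):
    # Checks each number on its own: every number is compared with
    # whole pairs like [81.0, 100.0], so it never matches
    return all(c in list_b for c in combined_a)

def pair_in_b(combined_a, list_b):
    # Checks the whole [x, y] pair against each pair in list_b
    return combined_a in list_b

print(numbers_in_b([56.0, 6.0], list_b))  # False
print(pair_in_b([56.0, 6.0], list_b))     # True
```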

(The UDF syntax is taken from "pyspark how do we check if a column value is contained in a list". I would really appreciate it if someone could explain the return udf part.)
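On the return udf part: udf(f, returnType) wraps a plain Python function so Spark can call it on the column values of each row, and calling the result with a Column such as col("combined_a") yields a new Column expression. udf_append is a factory: udf_append(list_b) builds and returns a UDF with list_b captured in the lambda's closure, and the trailing (col("combined_a")) applies that UDF to the column. Since the real udf needs a running Spark session, the sketch below swaps in a hypothetical fake_udf that simply applies the function to a Python list of values, only to show the closure mechanics:

```python
def fake_udf(f):
    # Hypothetical stand-in for pyspark.sql.functions.udf: instead of a
    # Spark column it takes a plain list of values and applies f to each
    return lambda values: [f(v) for v in values]

def udf_append(list_b):
    # list_b is captured by the lambda (a closure), so the returned
    # function only needs the column value as its argument
    return fake_udf(lambda combined_a: combined_a in list_b)

check = udf_append([[81.0, 100.0], [56.0, 6.0]])
print(check([[81.0, 100.0], [77.0, 88.0]]))  # [True, False]
```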

As output, I would like my df with an additional True/False column "combinaison_in_b_found":

_______________________________________________
id |    combined_a    | combinaison_in_b_found
1  |  [81.0, 100.0]   |    false
2  |  [56.0, 6.0]     |    true
...

Solution

  • Try this:

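    from pyspark.sql import SparkSession
    from pyspark.sql.functions import udf
    from pyspark.sql.types import BooleanType

    spark = SparkSession.builder.getOrCreate()

    df_a = spark.createDataFrame([(1,[81.0, 100.0]), (1, [56.0, 6.0]),(3,[77.0, 88.0]), (4,[42., 8.])], ('id', 'combined_a') )
    df_a.show()
    list_b = [ [81.0, 100.0], [56.0, 6.0], [10000.0, 45.32], [42.0, 6.0]]
    print('list_b: {}'.format(list_b))
    # Each array value reaches the UDF as a Python list, so `in` compares whole
    # pairs; BooleanType gives the Boolean column the question asks for
    my_udf = udf(lambda pair: pair in list_b, BooleanType())
    df_a = df_a.withColumn('combinaison_in_b_found', my_udf(df_a['combined_a']))
    df_a.show()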

    Here's the output:

    +---+-------------+
    | id|   combined_a|
    +---+-------------+
    |  1|[81.0, 100.0]|
    |  1|  [56.0, 6.0]|
    |  3| [77.0, 88.0]|
    |  4|  [42.0, 8.0]|
    +---+-------------+
    
    list_b: [[81.0, 100.0], [56.0, 6.0], [10000.0, 45.32], [42.0, 6.0]]
    +---+-------------+----------------------+
    | id|   combined_a|combinaison_in_b_found|
    +---+-------------+----------------------+
    |  1|[81.0, 100.0]|                  true|
    |  1|  [56.0, 6.0]|                  true|
    |  3| [77.0, 88.0]|                 false|
    |  4|  [42.0, 8.0]|                 false|
    +---+-------------+----------------------+
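A variation on the lambda above, not part of the original answer: for a long list_b, `pair in list_b` scans every pair for every row. Building a set of tuples once turns each check into a hash lookup, and an explicit None check covers null arrays, which a Python UDF receives as None. The function is shown in plain Python; wrapping it as udf(pair_found_in_b, BooleanType()) is the assumed PySpark usage, and the name pair_found_in_b is made up for illustration:

```python
list_b = [[81.0, 100.0], [56.0, 6.0], [10000.0, 45.32], [42.0, 6.0]]

# Build the lookup once; tuples are hashable, so each membership test
# is a hash lookup instead of a scan over every pair in list_b
lookup_b = {tuple(pair) for pair in list_b}

def pair_found_in_b(pair):
    # Null arrays arrive as None; treat them as "not found"
    return pair is not None and tuple(pair) in lookup_b

print([pair_found_in_b(p) for p in [[81.0, 100.0], [56.0, 6.0], [77.0, 88.0], [42.0, 8.0], None]])
# [True, True, False, False, False]
```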