How to get Modulo of a String in Pyspark


I'd like to calculate the modulo of an alphanumeric DataFrame column.

In pure Python I could use int(str, base) to convert it to a numeric value and then simply apply the modulo operator %.

For example:

>>> int('5c43466dc6d2870001fk8205', 24) % 64
5L

Ideally I'd like to avoid a Python UDF and use only built-in Spark functions.

For example my data source can be something like this:

from pyspark.sql.types import StringType

df = spark.createDataFrame(
    [
        '5c43466dc6d2870001fk8205', 
        '5c43466dc6d2870001fk8206', 
        '5c43466dc6d2870001fk8207'
    ], 
    StringType()
)

I'd like a new column with values [5L, 6L, 7L]


Solution

  • As @EnzoBnl pointed out, there is a function pyspark.sql.functions.conv which will:

    Convert a number in a string column from one base to another.

    But as he pointed out, your numbers are too big for this function to work properly.
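
To see how big these numbers are, here is a quick pure-Python check. It assumes that conv is limited to 64-bit integers internally, which is my understanding rather than something stated above.

```python
# The first sample value from the question, read as a base-24 integer.
value = int('5c43466dc6d2870001fk8205', 24)

# Number of bits needed to represent it exactly.
bits = value.bit_length()
print(bits)

# Anything above 64 bits cannot fit in a 64-bit integer.
fits_in_64_bits = bits <= 64
print(fits_in_64_bits)
```

Python integers are arbitrary precision, so int(str, base) handles the value without trouble; a fixed-width conversion would not.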

    However, you can use some math to simplify the calculation to something tractable.

    It can be shown (see Claim 1 below) that a base-24 number mod 64 is equal to the number formed by its last two digits mod 64. That is, you can get the desired output with the following code:

    from pyspark.sql.functions import conv, lit, substring
    
    df.withColumn(
        "mod", 
        conv(substring("value", -2, 2), 24, 10).cast("long") % lit(64).cast("bigint")
    ).show(truncate=False)
    #+------------------------+---+
    #|value                   |mod|
    #+------------------------+---+
    #|5c43466dc6d2870001fk8205|5  |
    #|5c43466dc6d2870001fk8206|6  |
    #|5c43466dc6d2870001fk8207|7  |
    #+------------------------+---+
    

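    The cast to long is required, presumably because conv returns a string column that has to be converted to a numeric type before taking the modulo. I had a source explaining this, but I can't seem to find it at the moment.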
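
As a sanity check outside Spark, the same last-two-digits trick can be reproduced in plain Python for the three sample rows. This is only a sketch: the slice s[-2:] stands in for substring("value", -2, 2), and the names samples, mods and full are mine.

```python
samples = [
    '5c43466dc6d2870001fk8205',
    '5c43466dc6d2870001fk8206',
    '5c43466dc6d2870001fk8207',
]

# s[-2:] plays the role of substring("value", -2, 2): the last two characters.
mods = [int(s[-2:], 24) % 64 for s in samples]
print(mods)  # [5, 6, 7]

# Compare against the modulo of the full arbitrary-precision integer.
full = [int(s, 24) % 64 for s in samples]
print(full == mods)  # True
```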


    Proof of Claim 1: If d is a base-24 representation of a number, then d % 64 = d_low % 64, where d_low represents the two least significant digits of d.

    Let's call our base-24 number d. If d has n digits, it can be represented in decimal (base-10) as follows:

    d = sum( di * 24**i for i in range(n) )
    

    Here di is the value, in base 10, of the i-th digit of d, counting from the least significant digit (i = 0).
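
The positional sum can be made concrete with a small runnable sketch. The helper digits_base24 and the ALPHABET constant are hypothetical names introduced here for illustration; the alphabet is assumed to be the one int(s, 24) uses (0-9 followed by a-n).

```python
import string

# Digit alphabet assumed to match int(s, 24): 0-9 followed by a-n.
ALPHABET = string.digits + string.ascii_lowercase[:14]

def digits_base24(s):
    """Return the digit values of s, least significant digit first."""
    return [ALPHABET.index(ch) for ch in reversed(s.lower())]

s = '5c43466dc6d2870001fk8205'
di = digits_base24(s)
n = len(di)

# d = sum of di * 24**i over all digit positions.
d = sum(di[i] * 24**i for i in range(n))
print(d == int(s, 24))  # True
```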

    We can equivalently write this sum as the sum of the lower 2 digits (2 least significant digits) and the upper n-2 digits (given n > 2):

    d = sum( di * 24**i for i in range(2) ) + sum( di * 24**i for i in range(2, n) )
    #   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    #           let's call this d_low                  let's call this d_high
    
    d = d_low + d_high
    

    Observe that d_high can be simplified by factoring out 24**2:

    d_high = (24**2) * sum( di * 24**(i-2) for i in range(2, n) )
    #                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    #                   for simplicity, let's call this x
    
    d_high = (24**2) * x
    

    Thus we have:

    d = d_low + (24**2) * x
    

    Now the number you want to calculate is d % 64.

    d % 64 = (d_low + (24**2) * x) % 64
    

    Since modular addition satisfies (x + y) % z = ( x % z + y % z ) % z, the above can be written as:

    d % 64 = (d_low % 64 + ((24**2) * x) % 64) % 64
    

    Now observe that 24**2 is an exact multiple of 64, because 24**2 contains the factor 2**6:

    24**2 = ((2**3)*3)**2 = (2**6)*(3**2) = 64*9
    

    Thus (24**2) % 64 = 0. It follows then that ((24**2) * x) % 64 = 0.
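
This step is easy to confirm numerically. The snippet below is just a spot check over an arbitrary range of x values, not a replacement for the argument.

```python
# 24**2 = 576 = 64 * 9, so it leaves no remainder when divided by 64.
square_mod = (24**2) % 64
print(square_mod)  # 0

# Any multiple of 24**2 is therefore also a multiple of 64.
remainders = {((24**2) * x) % 64 for x in range(1000)}
print(remainders)  # {0}
```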

    Consequently we can now write:

    d % 64 = (d_low % 64 + 0 % 64) % 64
           = (d_low % 64 + 0) % 64
           = d_low % 64
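
For anyone who wants an empirical cross-check of the claim, here is a small randomized test in plain Python. The function check_claim, the fixed seed, and the length range are arbitrary choices of mine; strings of length 1 or 2 are included, where the slice s[-2:] is simply the whole string.

```python
import random
import string

# Digit alphabet assumed to match int(s, 24): 0-9 followed by a-n.
ALPHABET = string.digits + string.ascii_lowercase[:14]
rng = random.Random(0)  # fixed seed so the run is repeatable

def check_claim(trials=10_000):
    """Return True if d % 64 == d_low % 64 for every random base-24 string."""
    for _ in range(trials):
        n = rng.randint(1, 40)
        s = ''.join(rng.choice(ALPHABET) for _ in range(n))
        if int(s, 24) % 64 != int(s[-2:], 24) % 64:
            return False
    return True

claim_holds = check_claim()
print(claim_holds)  # True
```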