
Primary keys with Apache Spark


I have a JDBC connection from Apache Spark to PostgreSQL and want to insert some data into my database. When I use append mode, I need to specify an id for each DataFrame.Row. Is there any way for Spark to create primary keys?


Solution

  • Scala:

    If all you need are unique numbers, you can use zipWithUniqueId and recreate the DataFrame. First, some imports and dummy data:

    import sqlContext.implicits._
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.{StructType, StructField, LongType}
    
    val df = sc.parallelize(Seq(
        ("a", -1.0), ("b", -2.0), ("c", -3.0))).toDF("foo", "bar")
    

    Extract the schema for further use:

    val schema = df.schema
    

    Add id field:

    val rows = df.rdd.zipWithUniqueId.map{
       case (r: Row, id: Long) => Row.fromSeq(id +: r.toSeq)}
    

    Create DataFrame:

    val dfWithPK = sqlContext.createDataFrame(
      rows, StructType(StructField("id", LongType, false) +: schema.fields))
    

    The same thing in Python:

    from pyspark.sql import Row
    from pyspark.sql.types import StructField, StructType, LongType
    
    row = Row("foo", "bar")
    
    df = sc.parallelize([row("a", -1.0), row("b", -2.0), row("c", -3.0)]).toDF()
    
    # df has to exist before its columns can be used to define the indexed Row class
    row_with_index = Row(*["id"] + df.columns)
    
    def make_row(columns):
        def _make_row(row, uid):
            row_dict = row.asDict()
            return row_with_index(*[uid] + [row_dict.get(c) for c in columns])
        return _make_row
    
    f = make_row(df.columns)
    
    df_with_pk = (df.rdd
        .zipWithUniqueId()
        .map(lambda x: f(*x))
        .toDF(StructType([StructField("id", LongType(), False)] + df.schema.fields)))
    

    If you prefer consecutive numbers, you can replace zipWithUniqueId with zipWithIndex, but it is a little bit more expensive.
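
    For the Scala version, here is a minimal sketch of that variant, reusing df, schema and the imports from the snippets above (zipWithIndex typically has to run an extra Spark job to compute per-partition offsets, which is where the extra cost comes from):

    val rowsWithIndex = df.rdd.zipWithIndex.map {
       // zipWithIndex assigns consecutive ids 0, 1, 2, ... across partitions
       case (r: Row, id: Long) => Row.fromSeq(id +: r.toSeq)
    }
    
    val dfWithConsecutivePK = sqlContext.createDataFrame(
      rowsWithIndex, StructType(StructField("id", LongType, false) +: schema.fields))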

    Directly with DataFrame API:

    (universal across Scala, Python, Java and R, with pretty much the same syntax)

    Previously I missed the monotonicallyIncreasingId function, which should work just fine as long as you don't require consecutive numbers:

    import org.apache.spark.sql.functions.monotonicallyIncreasingId
    
    df.withColumn("id", monotonicallyIncreasingId).show()
    // +---+----+-----------+
    // |foo| bar|         id|
    // +---+----+-----------+
    // |  a|-1.0|17179869184|
    // |  b|-2.0|42949672960|
    // |  c|-3.0|60129542144|
    // +---+----+-----------+
    

    While useful, monotonicallyIncreasingId is non-deterministic. Not only may the ids differ from execution to execution, but without additional tricks they cannot be used to identify rows when subsequent operations contain filters.
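
    One commonly suggested workaround, sketched below under the assumption that the data fits comfortably in cache (it is not a hard guarantee for every execution plan), is to materialize the ids before any filtering so that later lookups see the same values:

    import org.apache.spark.sql.functions.monotonicallyIncreasingId
    
    val dfWithId = df.withColumn("id", monotonicallyIncreasingId)
    dfWithId.cache()   // caching is one option; writing out and reading back is more robust
    dfWithId.count()   // force evaluation so the ids are computed once
    
    // subsequent filters now run against the materialized ids
    val negativeBar = dfWithId.where($"bar" < -1.5)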

    Note:

    It is also possible to use the rowNumber window function:

    from pyspark.sql.window import Window
    from pyspark.sql.functions import rowNumber
    
    w = Window().orderBy()
    df.withColumn("id", rowNumber().over(w)).show()
    

    Unfortunately:

    WARN Window: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.

    So unless you have a natural way to partition your data and ensure uniqueness, it is not particularly useful at this moment.
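
    To tie this back to the original question: once an id column has been generated (e.g. dfWithPK from above), the DataFrame can be appended to PostgreSQL over JDBC. A minimal sketch, where the connection url, table name and credentials are placeholders to adapt:

    val props = new java.util.Properties()
    props.setProperty("user", "myuser")            // placeholder credentials
    props.setProperty("password", "mypassword")
    props.setProperty("driver", "org.postgresql.Driver")
    
    dfWithPK.write
      .mode("append")
      .jdbc("jdbc:postgresql://localhost:5432/mydb", "my_table", props)

    Keep in mind that all of the approaches above generate ids starting from scratch on each run, so they are only safe as primary keys if they cannot collide with keys already present in the target table.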