Tags: apache-spark, azure-databricks

How to store a Spark DataFrame schema in a variable in a Databricks notebook?


Using Spark Scala on Azure Databricks.

I have a DataFrame (df1) with 100+ columns, and I need to create another DataFrame (df2) with the same schema. How can I store the schema of df1 in a variable and apply it to df2?

val inputDF = Seq(("00163E0F765C1ED79593228BF70CEE41" ,"PD PUMPS")
                       ,("00164E0F775C1ED79593228BF70CEE42" ,"PD PUMPS")
                       ,("00165E0F785C1ED79593228BF70CEE43" ,"PD PUMPS")
                       ,("00166E0F795C1ED79593228BF70CEE44" ,"PD PUMPS")
                       ,("00167E0F405C1ED79593228BF70CEE45" ,"PD PUMPS")
                   ).toDF("objectID")
val expectedDF = Seq(("00163E0F765C1ED79593228BF70CEE41" ,"PD PUMPS1")
                       ,("00164E0F775C1ED79593228BF70CEE42" ,"PD PUMPS1")
                       ,("00165E0F785C1ED79593228BF70CEE43" ,"PD PUMPS1")
                       ,("00166E0F795C1ED79593228BF70CEE44" ,"PD PUMPS1")
                       ,("00167E0F405C1ED79593228BF70CEE45" ,"PD PUMPS1")
                       ).toDF("objectID","equipmentName", inputDF.schema)

The purpose of doing this: I am writing a unit test case for a function which adds a column to the DataFrame passed as a parameter. So I need to create inputDF, and then create expectedDF with one more column than inputDF.
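For context, here is a minimal sketch of the kind of test described; the function name addEquipmentName and the literal column value are assumptions for illustration, not part of the original code:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.lit

// Hypothetical function under test: adds one column to whatever DataFrame is passed in.
def addEquipmentName(df: DataFrame): DataFrame =
  df.withColumn("equipmentName", lit("PD PUMPS1"))

// The test compares the function's output against expectedDF,
// so both DataFrames need to share the same schema.
val actualDF = addEquipmentName(inputDF)
assert(actualDF.schema == expectedDF.schema)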


Solution

  • Given that for the inputDF generation you should use a sequence with just one column (or pass two column names to the toDF method), I would do as follows:

    // In a standalone application you would also need: import spark.implicits._
    // (it is already in scope in a Databricks notebook).
    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.{StringType, StructField}
    
    val inputDF = Seq("00163E0F765C1ED79593228BF70CEE41"
      ,"00164E0F775C1ED79593228BF70CEE42"
      ,"00165E0F785C1ED79593228BF70CEE43"
      ,"00166E0F795C1ED79593228BF70CEE44"
      ,"00167E0F405C1ED79593228BF70CEE45"
    ).toDF("objectID")
    
    val seq = Seq(("00163E0F765C1ED79593228BF70CEE41" ,"PD PUMPS1")
      ,("00164E0F775C1ED79593228BF70CEE42" ,"PD PUMPS1")
      ,("00165E0F785C1ED79593228BF70CEE43" ,"PD PUMPS1")
      ,("00166E0F795C1ED79593228BF70CEE44" ,"PD PUMPS1")
      ,("00167E0F405C1ED79593228BF70CEE45" ,"PD PUMPS1")
    )
    
    // Turn each tuple into a generic Row so that createDataFrame can
    // pair the values with an explicit schema.
    val rdd = spark.sparkContext.parallelize(seq)
    
    val rows: RDD[Row] = rdd.map((row: (String, String)) => {
      Row(row.productIterator.toList: _*)
    })
    
    // Reuse the schema of inputDF, appending one StructField for the new column.
    val expectedDF = spark.createDataFrame(rows,
      inputDF.schema.add(StructField("NewColumn", StringType)))
    
    expectedDF.show()

    The idea is to create a DataFrame from the sequence with createDataFrame, passing it the schema (a StructType) of the old DataFrame with one extra entry (a StructField) appended.
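
    Since the question title asks how to store the schema in a variable: inputDF.schema is an ordinary StructType value, so it can be kept in a val and reused directly. A minimal sketch building on the code above:

    import org.apache.spark.sql.types.{StringType, StructField, StructType}
    
    // The schema is a plain StructType value, so it can be stored in a variable...
    val storedSchema: StructType = inputDF.schema
    
    // ...and extended with the same extra field used above.
    val expectedSchema: StructType = storedSchema.add(StructField("NewColumn", StringType))
    
    // Handy in a unit test to assert that the function under test produced
    // exactly one extra column with the expected name and type.
    assert(expectedDF.schema == expectedSchema)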