Tags: sql-server, apache-spark, pyspark, apache-spark-sql

PySpark data type translation to SQL Server data types on df.write


How do PySpark data types get translated to SQL Server data types on df.write() when using JDBC (I am using com.microsoft.azure:spark-mssql-connector_2.12:1.3.0)?

With the createTableColumnTypes option, one can specify Spark types:

The database column data types to use instead of the defaults, when creating the table. Data type information should be specified in the same format as CREATE TABLE columns syntax (e.g: "name CHAR(64), comments VARCHAR(1024)"). The specified types should be valid spark sql data types.

But how do these translate to SQL Server types? I could not find a mapping table. For instance, float gets translated to REAL in SQL Server.

Also, is it possible to specify the target schema directly?


Solution

  • We can figure out how the type mapping is done by having a look at the source code. I'm not sure which Spark version you're on, so I'll assume you are on version 3.4.1 (the most recent version as I'm writing this).

    In the file called JdbcDialects.scala, we can find the following documentation:

    /**
     * ...
     *
     * Currently, the only thing done by the dialect is type mapping.
     * `getCatalystType` is used when reading from a JDBC table and `getJDBCType`
     * is used when writing to a JDBC table.  If `getCatalystType` returns `null`,
     * the default type handling is used for the given JDBC type.  Similarly,
     * if `getJDBCType` returns `(null, None)`, the default type handling is used
     * for the given Catalyst type.
     */
    

    So when writing to a JDBC table (which is what you're interested in), a function called getJDBCType is called. And if that function returns None, the default type handling is used for that type.
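    The lookup order described in that comment can be modelled in a few lines. This is a minimal Python sketch, not Spark code: the hook functions, the toy mappings and the representation of Spark types as plain strings are all hypothetical. In Spark itself the equivalent logic sits in JdbcUtils.getJdbcType, which, as far as I can tell from the 3.4 source, tries the dialect's getJDBCType first, then getCommonJDBCType, and raises an error if both come back empty.

```python
from typing import Callable, Optional

# Hypothetical model: a Spark type is just its class name, e.g. "StringType",
# and a "hook" returns a database type name, or None if it has no mapping.
TypeHook = Callable[[str], Optional[str]]


def get_jdbc_type(dt: str, dialect_hook: TypeHook, common_hook: TypeHook) -> str:
    """Ask the dialect first; if it returns None, fall back to the common defaults."""
    resolved = dialect_hook(dt)
    if resolved is None:
        resolved = common_hook(dt)
    if resolved is None:
        # Neither mapping knows this type, so there is nothing to write it as.
        raise ValueError(f"Can't get JDBC type for {dt}")
    return resolved


# Toy hooks with only a couple of entries, just to show the precedence.
def toy_dialect(dt: str) -> Optional[str]:
    return {"StringType": "NVARCHAR(MAX)"}.get(dt)


def toy_common(dt: str) -> Optional[str]:
    return {"StringType": "TEXT", "FloatType": "REAL"}.get(dt)


print(get_jdbc_type("StringType", toy_dialect, toy_common))  # NVARCHAR(MAX): dialect wins
print(get_jdbc_type("FloatType", toy_dialect, toy_common))   # REAL: falls back to common
```

    The point of the design is that a dialect only has to list the types it wants to treat differently; everything else silently inherits the shared defaults.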

    We find this function for the SQL Server dialect in MsSqlServerDialect.scala:

      override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
        case TimestampType => Some(JdbcType("DATETIME", java.sql.Types.TIMESTAMP))
        case TimestampNTZType => Some(JdbcType("DATETIME", java.sql.Types.TIMESTAMP))
        case StringType => Some(JdbcType("NVARCHAR(MAX)", java.sql.Types.NVARCHAR))
        case BooleanType => Some(JdbcType("BIT", java.sql.Types.BIT))
        case BinaryType => Some(JdbcType("VARBINARY(MAX)", java.sql.Types.VARBINARY))
        case ShortType if !SQLConf.get.legacyMsSqlServerNumericMappingEnabled =>
          Some(JdbcType("SMALLINT", java.sql.Types.SMALLINT))
        case _ => None
      }
    

    So we already have a bunch of mappings! For example, StringType is mapped to NVARCHAR(MAX).
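    Transcribed into Python, the SQL Server dialect boils down to a small lookup table plus one conditional case. Again this is only a sketch: types are modelled as plain strings, and the legacy_numeric_mapping parameter is a hypothetical stand-in for the legacyMsSqlServerNumericMappingEnabled setting read from SQLConf in the Scala code.

```python
from typing import Optional


def mssql_jdbc_type(dt: str, legacy_numeric_mapping: bool = False) -> Optional[str]:
    """Python transcription of MsSqlServerDialect.getJDBCType as quoted above (Spark 3.4.1).

    Returns None when the dialect has no mapping, meaning the caller should
    fall back to the common defaults.
    """
    overrides = {
        "TimestampType": "DATETIME",
        "TimestampNTZType": "DATETIME",
        "StringType": "NVARCHAR(MAX)",
        "BooleanType": "BIT",
        "BinaryType": "VARBINARY(MAX)",
    }
    if dt in overrides:
        return overrides[dt]
    # ShortType is only mapped to SMALLINT when the legacy numeric mapping is off.
    if dt == "ShortType" and not legacy_numeric_mapping:
        return "SMALLINT"
    return None


print(mssql_jdbc_type("StringType"))  # NVARCHAR(MAX)
print(mssql_jdbc_type("FloatType"))   # None: not handled by the SQL Server dialect
```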

    But there are several types we don't find in here: what about float, for example? Those fall through to the getCommonJDBCType method, defined in JdbcUtils.scala:

      /**
       * Retrieve standard jdbc types.
       *
       * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]])
       * @return The default JdbcType for this DataType
       */
      def getCommonJDBCType(dt: DataType): Option[JdbcType] = {
        dt match {
          case IntegerType => Option(JdbcType("INTEGER", java.sql.Types.INTEGER))
          case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT))
          case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE))
          case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT))
          case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT))
          case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT))
          case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT))
          case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB))
          case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB))
          case CharType(n) => Option(JdbcType(s"CHAR($n)", java.sql.Types.CHAR))
          case VarcharType(n) => Option(JdbcType(s"VARCHAR($n)", java.sql.Types.VARCHAR))
          case TimestampType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP))
          // This is a common case of timestamp without time zone. Most of the databases either only
          // support TIMESTAMP type or use TIMESTAMP as an alias for TIMESTAMP WITHOUT TIME ZONE.
          // Note that some dialects override this setting, e.g. as SQL Server.
          case TimestampNTZType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP))
          case DateType => Option(JdbcType("DATE", java.sql.Types.DATE))
          case t: DecimalType => Option(
            JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL))
          case _ => None
        }
      }
    

    So now you know which of your Spark types get translated to which types in SQL Server! First look in the SQL Server-specific getJDBCType method; if the type isn't there, you should find it in the getCommonJDBCType method. A type that neither method handles (for example ArrayType) has no JDBC mapping, and the write fails with an error.
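    Putting both tables together, the whole translation for SQL Server can be sketched as one Python function. The dictionaries are copied from the two Scala snippets above (Spark 3.4.1, with the legacy numeric mapping assumed to be off); encoding parametric types as tuples such as ("DecimalType", 10, 2) is a hypothetical convention used only in this sketch.

```python
from typing import Optional, Tuple, Union

# Hypothetical representation: simple types are their Spark class name,
# parametric ones are tuples, e.g. ("DecimalType", 10, 2) or ("VarcharType", 255).
SparkType = Union[str, Tuple]

# From MsSqlServerDialect.getJDBCType (ShortType assumes legacy mapping is off).
MSSQL_OVERRIDES = {
    "TimestampType": "DATETIME",
    "TimestampNTZType": "DATETIME",
    "StringType": "NVARCHAR(MAX)",
    "BooleanType": "BIT",
    "BinaryType": "VARBINARY(MAX)",
    "ShortType": "SMALLINT",
}

# From JdbcUtils.getCommonJDBCType (non-parametric cases only).
COMMON_DEFAULTS = {
    "IntegerType": "INTEGER",
    "LongType": "BIGINT",
    "DoubleType": "DOUBLE PRECISION",
    "FloatType": "REAL",
    "ShortType": "INTEGER",
    "ByteType": "BYTE",
    "BooleanType": "BIT(1)",
    "StringType": "TEXT",
    "BinaryType": "BLOB",
    "TimestampType": "TIMESTAMP",
    "TimestampNTZType": "TIMESTAMP",
    "DateType": "DATE",
}


def common_type(dt: SparkType) -> Optional[str]:
    """Common defaults, including the parametric CHAR/VARCHAR/DECIMAL cases."""
    if isinstance(dt, tuple):
        name, *params = dt
        if name == "DecimalType":
            precision, scale = params
            return f"DECIMAL({precision},{scale})"
        if name in ("CharType", "VarcharType"):
            (length,) = params
            kind = "CHAR" if name == "CharType" else "VARCHAR"
            return f"{kind}({length})"
        return None
    return COMMON_DEFAULTS.get(dt)


def sql_server_type(dt: SparkType) -> str:
    """SQL Server dialect first, then the common defaults, else an error."""
    if dt in MSSQL_OVERRIDES:
        return MSSQL_OVERRIDES[dt]
    resolved = common_type(dt)
    if resolved is None:
        raise ValueError(f"no JDBC type for {dt!r}")
    return resolved


for t in ["FloatType", "StringType", "ShortType", ("DecimalType", 10, 2)]:
    print(t, "->", sql_server_type(t))
```

    For instance, FloatType resolves to REAL, matching what you observed, while ShortType ends up as SMALLINT rather than the INTEGER from the common table, because the dialect is consulted first.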

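    As you found out yourself, you can override these defaults with the createTableColumnTypes option when writing. The table name can also be schema-qualified (as with "schema.tablename" in the example), which is how you write to a specific target schema. An example can be found in datasource.py: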

        jdbcDF.write \
            .option("createTableColumnTypes", "name CHAR(64), comments VARCHAR(1024)") \
            .jdbc("jdbc:postgresql:dbserver", "schema.tablename",
                  properties={"user": "username", "password": "password"})