Tags: python, pyspark, pyspark-schema

Comparing an extracted DataFrame schema to a targeted schema


I'm working on some data governance that takes the schema of an extracted query and validates it before doing any transformations on it. It's broken up into two classes: one class extracts the data, and another validates it against an API. If the schema matches, we move on to checking the quality of the data. However, my basic if statement always returns False even when the schemas match. I could use some assistance. Code:

from pyspark.sql import SparkSession
from pyspark.sql.types import *
import json

# Placeholder: path to the PostgreSQL JDBC driver jar used below.
SPARK_JARS = "/path/to/postgresql.jar"


class Session:

    def __init__(self, config_file, inferred=True, header=True):
        self.spark = self._create_session()
        self.schema = inferred  # not used by the JDBC read below
        self.header = header    # not used by the JDBC read below
        self.config_file = config_file
        self.connections = self.parse_file()
        # Connection details come from the JSON config file.
        self.db_name = self.connections['db_name']
        self.password = self.connections['password']
        self.name = self.connections['username']
        self.table_name = self.connections['tbl_name']

    @staticmethod
    def _create_session():
        spark = (
            SparkSession
            .builder
            .appName("test")
            .config("spark.jars", SPARK_JARS)  # JDBC driver jar on the classpath
            .getOrCreate()
        )
        return spark

    def parse_file(self):
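        # Expected shape of connections.json (keys read in __init__):
        # {"db_name": "...", "username": "...", "password": "...", "tbl_name": "..."}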
        with open(self.config_file) as test:
            data = json.load(test)
            return data

    def pull_data(self):
        # Read the table over JDBC; the driver jar must be on spark.jars.
        df = self.spark.read.format("jdbc") \
            .option("url", f"jdbc:postgresql://localhost:5432/{self.db_name}") \
            .option("dbtable", self.table_name) \
            .option("user", self.name) \
            .option("password", self.password) \
            .option("driver", "org.postgresql.Driver") \
            .load()
        return df

DESIRED_SCHEMA = StructType([
    StructField(name='order_id', dataType=IntegerType(), nullable=True),
    StructField(name='order_date', dataType=TimestampType(), nullable=True),
    StructField(name='order_customer_id', dataType=IntegerType(), nullable=True),
    StructField(name='order_status', dataType=StringType(), nullable=True)
])


class SchemaCheck:

    def __init__(self, targeted_schema=DESIRED_SCHEMA):
        self.schema = targeted_schema

    def val_schema(self, df):
        if not df.schema == self.schema:
            return False
        return True

if __name__ == "__main__":
    valid_schema = SchemaCheck()
    spark = Session('connections.json')
    my_df = spark.pull_data()
    my_df.printSchema()
    print(my_df.schema)
    print(valid_schema.schema)
    print(valid_schema.val_schema(my_df))

When everything is printed out, the results look like this:

root
 |-- order_id: integer (nullable = true)
 |-- order_date: timestamp (nullable = true)
 |-- order_customer_id: integer (nullable = true)
 |-- order_status: string (nullable = true)

StructType([StructField('order_id', IntegerType(), True), StructField('order_date', TimestampType(), True), StructField('order_customer_id', IntegerType(), True), StructField('order_status', StringType(), True)])
StructType([StructField('order_id', IntegerType(), True), StructField('order_date', TimestampType(), True), StructField('order_customer_id', IntegerType(), True), StructField('order_status', StringType(), True)])
False

I'm not sure why my return is False. Could someone shed some light on what I'm missing, please?


Solution

  • Comparing StructType objects is trickier than it looks. PySpark's DataType base class does define an __eq__ method, and StructType inherits it, so == here is a genuine content comparison rather than an identity check. The catch is that it compares the objects' full __dict__, including each StructField's metadata.

    A JDBC read typically attaches extra field metadata (for example, {"scale": 0} on numeric columns), and neither printSchema() nor the repr displays metadata. That is why two schemas that print identically can still compare unequal.
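    Here is a minimal standalone sketch of the effect (a hypothetical example, not from the question's code): the two schemas below print identically, but one carries metadata, so == is False while simpleString() still matches.

    from pyspark.sql.types import StructType, StructField, IntegerType

    a = StructType([StructField('order_id', IntegerType(), True)])
    # A JDBC read can attach metadata like this; the repr does not show it.
    b = StructType([StructField('order_id', IntegerType(), True, metadata={'scale': 0})])

    print(a)        # StructType([StructField('order_id', IntegerType(), True)])
    print(b)        # prints exactly the same; metadata is hidden
    print(a == b)   # False: __eq__ also compares field metadata
    print(a.simpleString() == b.simpleString())  # True: simpleString() ignores metadata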

    For your specific case, you could either compare field by field (in the end, a StructType is just a sequence of StructFields; there is a sketch of that after the note below), or just do something simple like this:

    def val_schema(self, df):
        return self.schema.simpleString() == df.schema.simpleString()
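
    For the schema in the question, both sides reduce to struct<order_id:int,order_date:timestamp,order_customer_id:int,order_status:string>, so this comparison ignores both nullability and field metadata.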
    

    Note that this will not validate the nullability of the fields, but if you really want to check that too, you can just go through the fields one by one, as sketched below.
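    Here is a sketch of that stricter check, written as an extra method you could add to the SchemaCheck class (val_schema_strict is a hypothetical name, not part of the original code):

    def val_schema_strict(self, df):
        # Compare name, type, and nullability per field, ignoring metadata.
        # (For nested struct types, dataType equality would again include
        # metadata, so recurse over the fields if you need to handle that.)
        if len(df.schema.fields) != len(self.schema.fields):
            return False
        return all(
            got.name == want.name
            and got.dataType == want.dataType
            and got.nullable == want.nullable
            for got, want in zip(df.schema.fields, self.schema.fields)
        )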

    Hope it helps!

    See more:

    • https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.types.StructType.html
    • https://realpython.com/python-is-identity-vs-equality/