I'm working on a data governance step that takes the schema of an extracted query and validates it before any transformations are run. It's broken up into two classes: one class extracts the data, and another validates the extracted schema against a desired definition. If the schema matches, then we move on to checking the quality of the data. However, my basic if statement always returns False, even when the schemas match. I could use some assistance:
Code:
from pyspark.sql import SparkSession
from pyspark.sql.types import *
import json

class Session:
    def __init__(self, config_file, inferred=True, header=True):
        self.spark = self._create_session()
        self.schema = inferred
        self.header = header
        self.config_file = config_file
        self.connections = self.parse_file()
        self.db_name = self.connections['db_name']
        self.password = self.connections['password']
        self.name = self.connections['username']
        self.table_name = self.connections["tbl_name"]

    @staticmethod
    def _create_session():
        # SPARK_JARS (the path to the PostgreSQL JDBC jar) is defined elsewhere
        spark = (
            SparkSession
            .builder
            .appName("test")
            .config("spark.jars", SPARK_JARS)
            .getOrCreate()
        )
        return spark

    def parse_file(self):
        with open(self.config_file) as test:
            data = json.load(test)
        return data

    def pull_data(self):
        df = self.spark.read.format("jdbc") \
            .option("url", f"jdbc:postgresql://localhost:5432/{self.db_name}") \
            .option("dbtable", f'{self.table_name}') \
            .option("user", f"{self.name}") \
            .option("password", f"{self.password}") \
            .option("driver", "org.postgresql.Driver") \
            .load()
        return df
DESIRED_SCHEMA = StructType([
    StructField(name='order_id', dataType=IntegerType(), nullable=True),
    StructField(name='order_date', dataType=TimestampType(), nullable=True),
    StructField(name='order_customer_id', dataType=IntegerType(), nullable=True),
    StructField(name='order_status', dataType=StringType(), nullable=True)
])
class SchemaCheck:
    def __init__(self, targeted_schema=DESIRED_SCHEMA):
        self.schema = targeted_schema

    def val_schema(self, df):
        if not df.schema == self.schema:
            return False
        return True
if __name__ == "__main__":
    valid_schema = SchemaCheck()
    spark = Session('connections.json')
    my_df = spark.pull_data()
    my_df.printSchema()
    print(my_df.schema)
    print(valid_schema.schema)
    print(valid_schema.val_schema(my_df))
When everything is printed out, the results look like this:
root
|-- order_id: integer (nullable = true)
|-- order_date: timestamp (nullable = true)
|-- order_customer_id: integer (nullable = true)
|-- order_status: string (nullable = true)
StructType([StructField('order_id', IntegerType(), True), StructField('order_date', TimestampType(), True), StructField('order_customer_id', IntegerType(), True), StructField('order_status', StringType(), True)])
StructType([StructField('order_id', IntegerType(), True), StructField('order_date', TimestampType(), True), StructField('order_customer_id', IntegerType(), True), StructField('order_status', StringType(), True)])
False
I'm not sure why the return value is False. Could someone shed some light on what I'm missing, please?
Comparing objects in Python is not always as simple as it looks. In the PySpark source, StructType actually inherits its __eq__ from DataType, which compares the two objects' classes and their full __dict__. For a StructField, that __dict__ includes a metadata dictionary that neither printSchema() nor the repr displays. JDBC sources typically attach metadata to the fields they read (for example {"scale": 0} on numeric columns), so two schemas that print identically can still compare unequal. That is most likely why you get False...
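One quick way to confirm this is to diff the fields directly and print the metadata that the repr hides (a minimal sketch reusing my_df and DESIRED_SCHEMA from your script):

    # Compare each pair of fields and print the metadata dictionaries,
    # which printSchema() and the repr do not display.
    for actual, expected in zip(my_df.schema.fields, DESIRED_SCHEMA.fields):
        if actual != expected:
            print(actual.name, actual.metadata, expected.metadata)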
For your specific case, you could either compare field by field (in the end, a StructType is just a sequence of StructFields), or just do something simple like this:
    def val_schema(self, df):
        return self.schema.simpleString() == df.schema.simpleString()
Note that simpleString() ignores both metadata and nullability, so this will not validate whether the fields are nullable; if you really want to verify that, you can go through all the fields yourself, as in the sketch below.
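A field-per-field variant that does check nullability but deliberately skips the source-attached metadata could look like this (a sketch, assuming the field order must match exactly):

    def val_schema(self, df):
        # Compare name, type, and nullability for each field, ignoring
        # the metadata dictionary that JDBC sources attach.
        if len(df.schema.fields) != len(self.schema.fields):
            return False
        return all(
            actual.name == expected.name
            and actual.dataType == expected.dataType
            and actual.nullable == expected.nullable
            for actual, expected in zip(df.schema.fields, self.schema.fields)
        )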
Hope it helps!
See more:
https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.types.StructType.html
https://realpython.com/python-is-identity-vs-equality/