Search code examples
Tags: python, azure, sqlalchemy, azure-sql-database, pyodbc

How to remove "Trusted_Connection=Yes" from the pyodbc connection string with SQLAlchemy events and an Azure Active Directory token when the engine is created inside a class?


The SQLAlchemy documentation shows how to remove the Trusted_Connection part of the connection string by applying the `event.listens_for` decorator to an engine that is created at module level (not inside a function or a class). Here is the example:

import struct
from sqlalchemy import create_engine, event
from sqlalchemy.engine.url import URL
from azure import identity

SQL_COPT_SS_ACCESS_TOKEN = 1256  # Connection option for access tokens, as defined in msodbcsql.h
TOKEN_URL = "https://database.windows.net/"  # The token URL for any Azure SQL database

# No user name in the URL; provide_token below strips the ";Trusted_Connection=Yes"
# that SQLAlchemy adds to the ODBC connection string for this URL.
connection_string = "mssql+pyodbc://@my-server.database.windows.net/myDb?driver=ODBC+Driver+17+for+SQL+Server"

# The engine exists at module import time, which is what allows the
# @event.listens_for decorator below to reference it directly.
engine = create_engine(connection_string)

# Created once and shared by every connection the engine opens.
azure_credentials = identity.DefaultAzureCredential()

@event.listens_for(engine, "do_connect")
def provide_token(dialect, conn_rec, cargs, cparams):
    """Handler for the engine's ``do_connect`` event: use an Azure AD token for auth.

    Mutates ``cargs`` (positional connect arguments; ``cargs[0]`` holds the
    connection string, as the replace below assumes) and ``cparams`` (keyword
    connect arguments) in place.
    """
    # remove the "Trusted_Connection" parameter that SQLAlchemy adds
    cargs[0] = cargs[0].replace(";Trusted_Connection=Yes", "")

    # create token credential
    raw_token = azure_credentials.get_token(TOKEN_URL).token.encode("utf-16-le")
    # 4-byte little-endian length prefix followed by the UTF-16-LE token bytes.
    token_struct = struct.pack(f"<I{len(raw_token)}s", len(raw_token), raw_token)

    # apply it to keyword arguments
    cparams["attrs_before"] = {SQL_COPT_SS_ACCESS_TOKEN: token_struct}

The problem I am facing is that I connect to my database through a class called `Db`, whose `.engine` attribute is only created when the class is instantiated (which engine gets created depends on the environment). How can I use the decorator to remove "Trusted_Connection=Yes" from that engine when it only exists as an attribute of an instance?

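The code below fails with `AttributeError: type object 'Db' has no attribute 'engine'`. The decorator evaluates `Db.engine` as soon as the module is loaded, but `engine` is only assigned to instances inside `__init__`, so the class object itself never has it.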

import os
import struct
from azure.identity import DefaultAzureCredential
from sqlalchemy.engine.url import URL
from sqlalchemy import create_engine, event
from sqlalchemy import inspect


class Db:
    """Question version: builds an environment-specific engine and connects.

    ``engine`` is assigned in ``__init__``, so it exists only on instances,
    never on the class ``Db`` itself. No token hook is registered here.
    """
    def __init__(self, config: object) -> None:
        # config.environment selects the SERVER_<env> / DATABASE_<env> environment variables.
        url = URL.create(
            drivername="mssql+pyodbc",
            port=1433,
            query=dict(driver='ODBC Driver 18 for SQL Server'),
            host=f"tcp:{os.environ.get(f'SERVER_{config.environment}')}",
            database=os.environ.get(f'DATABASE_{config.environment}')
        )
        self.engine = create_engine(url=url, connect_args={"autocommit": True})
        # Connects immediately, before any code outside __init__ could attach a
        # listener to self.engine.
        self.connection = self.engine.connect()
        self.inspector = inspect(subject=self.engine)

    def close(self) -> None:
        """Close the connection opened in __init__."""
        self.connection.close()


# NOTE(review): the decorator line raises AttributeError when the module loads.
# ``engine`` is an instance attribute assigned in Db.__init__, so the class
# object ``Db`` has no ``engine``. The listener must instead be attached to an
# instance's engine, e.g. with ``event.listen(self.engine, "do_connect", ...)``.
# Also: ``identity`` and ``TOKEN_URL`` are not defined in this snippet (only
# ``DefaultAzureCredential`` is imported), so the body would raise NameError.
@event.listens_for(target=Db.engine, identifier="do_connect")
def provide_token(dialect, conn_rec, cargs, cparams):
    """``do_connect`` hook meant to replace integrated auth with an Azure AD access token."""
    # remove the "Trusted_Connection" parameter that SQLAlchemy adds
    cargs[0] = cargs[0].replace(";Trusted_Connection=Yes", "")

    azure_credentials = identity.DefaultAzureCredential()
    raw_token = azure_credentials.get_token(TOKEN_URL).token.encode("utf-16-le")
    token_struct = struct.pack(f"<I{len(raw_token)}s", len(raw_token), raw_token)

    # apply it to keyword arguments
    # 1256 is SQL_COPT_SS_ACCESS_TOKEN from msodbcsql.h.
    cparams["attrs_before"] = {1256: token_struct}

Solution

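  • After digging further: instead of the `event.listens_for` decorator, call `event.listen` inside `__init__`, once the engine exists. Pass it the target (`self.engine`), the identifier (`"do_connect"`) and the callback function (here the bound method `self._add_token`). Register it before the first `connect()` call so that this connection also receives the token. The callback should still strip ";Trusted_Connection=Yes" from `cargs[0]`, as in the documentation example.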

    Here is the solution:

    import os
    import struct
    from azure.identity import DefaultAzureCredential
    from sqlalchemy.engine.url import URL
    from sqlalchemy import create_engine, event
    from sqlalchemy import inspect
    
    
    class Db:
        """Database handle that authenticates to Azure SQL with an Azure AD access token.

        The engine is built per instance (its target depends on ``config.environment``),
        so the ``do_connect`` hook is attached with ``event.listen`` inside ``__init__``
        rather than with the ``@event.listens_for`` decorator, which would need the
        engine to exist when the class body is evaluated.
        """

        # ODBC connection attribute for passing an access token (SQL_COPT_SS_ACCESS_TOKEN in msodbcsql.h).
        SQL_COPT_SS_ACCESS_TOKEN = 1256
        # Token URL requested for Azure SQL Database.
        TOKEN_URL = "https://database.windows.net/"

        def __init__(self, config: object) -> None:
            """Create the engine, register the token hook and open a connection.

            Args:
                config: object exposing an ``environment`` attribute; it selects the
                    ``SERVER_<environment>`` and ``DATABASE_<environment>`` environment
                    variables used to build the URL.
            """
            url = URL.create(
                drivername="mssql+pyodbc",
                port=1433,
                query=dict(driver='ODBC Driver 18 for SQL Server'),
                host=f"tcp:{os.environ.get(f'SERVER_{config.environment}')}",
                database=os.environ.get(f'DATABASE_{config.environment}')
            )
            # One credential per Db, reused for every connection, instead of rebuilding
            # the credential chain each time the engine opens a connection.
            self._credential = DefaultAzureCredential()
            self.engine = create_engine(url=url, connect_args={"autocommit": True})
            # Registered before the first connect() so that connection gets a token too.
            event.listen(target=self.engine, identifier="do_connect", fn=self._add_token)
            self.connection = self.engine.connect()
            self.inspector = inspect(subject=self.engine)

        def _add_token(self, dialect, conn_rec, cargs, cparams) -> None:
            """``do_connect`` hook: replace integrated auth with an Azure AD access token.

            Mutates ``cargs`` (``cargs[0]`` is the ODBC connection string) and
            ``cparams`` in place before SQLAlchemy opens the DBAPI connection.
            """
            # SQLAlchemy adds ";Trusted_Connection=Yes" for this user-less URL; the ODBC
            # driver does not allow Trusted_Connection together with an access token.
            cargs[0] = cargs[0].replace(";Trusted_Connection=Yes", "")
            raw_token = self._credential.get_token(self.TOKEN_URL).token.encode("utf-16-le")
            # 4-byte little-endian length prefix followed by the UTF-16-LE token bytes.
            token_struct = struct.pack(f"<I{len(raw_token)}s", len(raw_token), raw_token)
            cparams["attrs_before"] = {self.SQL_COPT_SS_ACCESS_TOKEN: token_struct}

        def close(self) -> None:
            """Close the connection opened in ``__init__``."""
            self.connection.close()