The SQLAlchemy documentation shows how to remove the `Trusted_Connection` part of the connection string with the `event.listens_for` decorator. The decorator is applied to an engine that is created at module level, not inside a function or class. Here is the example:
import struct
from sqlalchemy import create_engine, event
from sqlalchemy.engine.url import URL
from azure import identity
# Connection attribute used to hand an access token to the ODBC driver
# (value as defined in msodbcsql.h).
SQL_COPT_SS_ACCESS_TOKEN = 1256
# Token URL that applies to every Azure SQL database.
TOKEN_URL = "https://database.windows.net/"

# Built once at module level and shared by every connection attempt.
azure_credentials = identity.DefaultAzureCredential()

# SQLAlchemy adds "Trusted_Connection=Yes" to the ODBC string generated for
# this URL; a "do_connect" listener on this engine has to strip it again.
connection_string = "mssql+pyodbc://@my-server.database.windows.net/myDb?driver=ODBC+Driver+17+for+SQL+Server"
engine = create_engine(url=connection_string)
@event.listens_for(engine, "do_connect")
def provide_token(dialect, conn_rec, cargs, cparams):
    """Switch each new DBAPI connection from trusted auth to an access token.

    Runs just before pyodbc connects. ``cargs[0]`` holds the ODBC connection
    string and ``cparams`` the keyword arguments; both are edited in place.
    """
    # Drop the trusted-connection flag SQLAlchemy put into the ODBC string.
    odbc_string = cargs[0]
    cargs[0] = odbc_string.replace(";Trusted_Connection=Yes", "")

    # The driver expects the token as UTF-16-LE bytes prefixed by their
    # length as a little-endian unsigned 32-bit integer.
    access_token = azure_credentials.get_token(TOKEN_URL).token
    token_bytes = access_token.encode("utf-16-le")
    byte_count = len(token_bytes)
    packed_token = struct.pack(f"<I{byte_count}s", byte_count, token_bytes)

    # Passed to pyodbc as a pre-connect connection attribute.
    cparams["attrs_before"] = {SQL_COPT_SS_ACCESS_TOKEN: packed_token}
The problem I am facing is that I connect to my database through a class called `Db`. `Db` has an `.engine` attribute that is only created when the class is instantiated (which engine is created depends on the environment). How can I use the decorator to remove `Trusted_Connection=Yes` from my engine when the engine is an instance attribute created at instantiation?
The code below fails with the error `AttributeError: type object 'Db' has no attribute 'engine'`:
import os
import struct
from azure.identity import DefaultAzureCredential
from sqlalchemy.engine.url import URL
from sqlalchemy import create_engine, event
from sqlalchemy import inspect
class Db:
    """Open a SQL Server connection whose target depends on the environment.

    ``config.environment`` selects the ``SERVER_<env>`` and ``DATABASE_<env>``
    environment variables that supply the host and database names.
    """

    def __init__(self, config: object) -> None:
        env_name = config.environment
        server = os.environ.get(f'SERVER_{env_name}')
        database_name = os.environ.get(f'DATABASE_{env_name}')
        url = URL.create(
            drivername="mssql+pyodbc",
            host=f"tcp:{server}",
            port=1433,
            database=database_name,
            query={'driver': 'ODBC Driver 18 for SQL Server'},
        )
        # The engine is assigned per instance here, so it exists on ``self``
        # after __init__ runs but never as an attribute of the class itself.
        self.engine = create_engine(url=url, connect_args={"autocommit": True})
        self.connection = self.engine.connect()
        self.inspector = inspect(subject=self.engine)

    def close(self) -> None:
        """Close the connection opened in ``__init__``."""
        self.connection.close()
# NOTE(review): this is the failing attempt described in the question.
# The decorator argument ``Db.engine`` is evaluated when this module is
# executed. ``Db`` has no class attribute ``engine``: ``engine`` is only
# assigned to each instance inside ``Db.__init__``. So this line raises
# ``AttributeError: type object 'Db' has no attribute 'engine'``.
# Note also that ``identity`` and ``TOKEN_URL`` are not defined in this
# snippet; the imports only bring in ``DefaultAzureCredential``.
@event.listens_for(target=Db.engine, identifier="do_connect")
def provide_token(dialect, conn_rec, cargs, cparams):
    """``do_connect`` hook intended to switch the connection to token auth.

    Strips the trusted-connection flag from the ODBC connection string in
    ``cargs[0]`` and passes a packed access token via
    ``cparams["attrs_before"]``.
    """
    # remove the "Trusted_Connection" parameter that SQLAlchemy adds
    cargs[0] = cargs[0].replace(";Trusted_Connection=Yes", "")
    # A new credential object is created on every connection attempt.
    azure_credentials = identity.DefaultAzureCredential()
    # Token encoded as UTF-16-LE, then packed as a little-endian 4-byte
    # length prefix followed by the token bytes.
    raw_token = azure_credentials.get_token(TOKEN_URL).token.encode("utf-16-le")
    token_struct = struct.pack(f"<I{len(raw_token)}s", len(raw_token), raw_token)
    # apply it to keyword arguments
    # 1256 is the SQL_COPT_SS_ACCESS_TOKEN connection option (msodbcsql.h).
    cparams["attrs_before"] = {1256: token_struct}
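After digging further, I found that you can call `event.listen` directly instead of using the `event.listens_for` decorator. You pass the callback function (`fn`) in addition to the `target` and `identifier`. The decorator does not work here because its arguments are evaluated when the module is loaded, before any `Db` instance (and therefore any engine) exists. `event.listen` is an ordinary function call, so it can be made inside `__init__`: right after `self.engine` is created and before the first connection is opened.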
Here is the solution:
import os
import struct
from azure.identity import DefaultAzureCredential
from sqlalchemy.engine.url import URL
from sqlalchemy import create_engine, event
from sqlalchemy import inspect
class Db:
    """Connection wrapper for an Azure SQL database selected by environment.

    The server and database names are read from the ``SERVER_<env>`` and
    ``DATABASE_<env>`` environment variables, where ``<env>`` is
    ``config.environment``. Authentication uses an access token from
    ``DefaultAzureCredential``, injected by a ``do_connect`` listener that is
    registered on this instance's engine. Usable as a context manager.
    """

    # ODBC connection attribute for passing an access token (msodbcsql.h).
    _SQL_COPT_SS_ACCESS_TOKEN = 1256
    # Token URL used for any Azure SQL database.
    _TOKEN_URL = "https://database.windows.net/"

    def __init__(self, config: object) -> None:
        """Build the engine, hook in token auth, and open a connection.

        Args:
            config: Object exposing an ``environment`` attribute that selects
                the ``SERVER_*`` / ``DATABASE_*`` environment variables.
        """
        url = URL.create(
            drivername="mssql+pyodbc",
            port=1433,
            query=dict(driver='ODBC Driver 18 for SQL Server'),
            host=f"tcp:{os.environ.get(f'SERVER_{config.environment}')}",
            database=os.environ.get(f'DATABASE_{config.environment}'),
        )
        # One credential per Db, reused by the listener for every new DBAPI
        # connection instead of being rebuilt on each connect.
        self._credential = DefaultAzureCredential()
        self.engine = create_engine(url=url, connect_args={"autocommit": True})
        # event.listen is an ordinary call (unlike the @event.listens_for
        # decorator), so it can target an engine that only exists once
        # __init__ runs. It must be registered before the connect() below.
        event.listen(target=self.engine, identifier="do_connect", fn=self._add_token)
        self.connection = self.engine.connect()
        self.inspector = inspect(subject=self.engine)

    def _add_token(self, dialect, conn_rec, cargs, cparams):
        """``do_connect`` listener: switch a new connection to token auth.

        ``cargs[0]`` (the ODBC connection string) and ``cparams`` (keyword
        arguments for the DBAPI connect call) are modified in place.
        """
        # SQLAlchemy adds ";Trusted_Connection=Yes" to the ODBC string; strip
        # it so the connection relies on the access token alone.
        cargs[0] = cargs[0].replace(";Trusted_Connection=Yes", "")
        token = self._credential.get_token(self._TOKEN_URL).token
        # UTF-16-LE token bytes, prefixed by their length as a little-endian
        # unsigned 32-bit integer.
        raw_token = token.encode("utf-16-le")
        token_struct = struct.pack(f"<I{len(raw_token)}s", len(raw_token), raw_token)
        cparams["attrs_before"] = {self._SQL_COPT_SS_ACCESS_TOKEN: token_struct}

    def close(self) -> None:
        """Close the connection and release the engine's pooled connections."""
        self.connection.close()
        # close() only checks the DBAPI connection back into the pool;
        # dispose() actually closes it. The engine stays usable afterwards.
        self.engine.dispose()

    def __enter__(self) -> "Db":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()