diff --git a/comma/database/driver.py b/comma/database/driver.py index ff07001..c16aab8 100755 --- a/comma/database/driver.py +++ b/comma/database/driver.py @@ -7,10 +7,14 @@ Provide a class for managing database connections and sessions import logging import os +import struct import urllib from contextlib import contextmanager +from typing import Any import sqlalchemy +from sqlalchemy.engine.url import URL +from azure.identity import DefaultAzureCredential from comma.database.model import Base, Distros, MonitoringSubjects from comma.exceptions import CommaDatabaseError, CommaDataError @@ -28,42 +32,73 @@ class DatabaseDriver: # Enable INFO-level logging when program is logging debug # It's not ideal, because the messages are INFO level, but only enabled with debug + # As defined in msodbcsql.h + self.SQL_COPT_SS_ACCESS_TOKEN = 1256 + if dry_run: db_file = "comma.db" LOGGER.info("Using local SQLite database at '%s'.", db_file) engine = sqlalchemy.create_engine(f"sqlite:///{db_file}", echo=echo) else: LOGGER.info("Connecting to remote database...") - engine = sqlalchemy.create_engine(self._get_mssql_conn_str(), echo=echo) + engine = self.create_engine("ODBC Driver 17 for SQL Server") Base.metadata.bind = engine Base.metadata.create_all(engine) self.session_factory = sqlalchemy.orm.sessionmaker(bind=engine) - @staticmethod - def _get_mssql_conn_str() -> sqlalchemy.engine.Engine: - """ - Create a connection string for MS SQL server and create engine instance - """ + def get_token(self) -> bytes: + try: + credential = DefaultAzureCredential() + cre_token = credential.get_token("https://database.windows.net/").token + token = cre_token.encode("utf-16-le") + token_struct = struct.pack(f"=I{len(token)}s", len(token), token) + return bytes(token_struct) + except Exception as e: + raise RuntimeError("Failed to obtain Azure AD token") from e - # Verify credentials are available - for envvar in ("COMMA_DB_URL", "COMMA_DB_NAME", "COMMA_DB_USERNAME", "COMMA_DB_PW"): - if not os.environ.get(envvar): - raise CommaDatabaseError(f"{envvar} is not defined in the current environment") + def create_engine(self, driver: str) -> Any: + if os.environ.get('COMMA_DB_USERNAME') and os.environ.get('COMMA_DB_PW'): + return self.create_engine_with_pass(driver) + else: + return self.create_engine_with_token(driver) - params = urllib.parse.quote_plus( - ";".join( - ( - "DRIVER={ODBC Driver 17 for SQL Server}", - f"SERVER={os.environ['COMMA_DB_URL']}", - f"DATABASE={os.environ['COMMA_DB_NAME']}", - f"UID={os.environ['COMMA_DB_USERNAME']}", - f"PWD={os.environ['COMMA_DB_PW']}", - ) + def create_engine_with_pass(self, driver: str) -> Any: + try: + return sqlalchemy.create_engine( + URL( + drivername="mssql+pyodbc", + username=os.environ['COMMA_DB_USERNAME'], + password=os.environ['COMMA_DB_PW'], + host=os.environ['COMMA_DB_URL'], + database=os.environ['COMMA_DB_NAME'], + query={"driver": driver}, + ), + pool_recycle=300, ) - ) + except Exception as e: + raise RuntimeError( + "Failed to create engine with username and password" + ) from e - return f"mssql+pyodbc:///?odbc_connect={params}" + def create_engine_with_token(self, driver: str) -> Any: + try: + query = { + "odbc_connect": ( + f"DRIVER={driver};DATABASE={os.environ['COMMA_DB_NAME']};" + f"SERVER={os.environ['COMMA_DB_URL']}" + ) + } + connect_args = { + "attrs_before": {self.SQL_COPT_SS_ACCESS_TOKEN: self.get_token()} + } + return sqlalchemy.create_engine( + URL("mssql+pyodbc", query=query), + connect_args=connect_args, + pool_recycle=300, + ) + except Exception as e: + raise RuntimeError("Failed to create engine with Azure AD token") from e @contextmanager def get_session(self) -> sqlalchemy.orm.session.Session: diff --git a/pyproject.toml b/pyproject.toml index f2afb2f..de59141 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "openpyxl ~= 3.1.2", "pydantic ~= 1.10.9", "ruamel.yaml ~= 0.17.32", - + "azure-identity ~= 1.17.1", ] dynamic = ["version"] license = {text = "MIT"}