This commit is contained in:
Kameron Carr 2024-09-10 14:41:55 -07:00 коммит произвёл Kameron Carr
Родитель 7f834eb56e
Коммит 62524f927b
2 изменённых файлов: 57 добавлений и 22 удалений

Просмотреть файл

@ -7,10 +7,14 @@ Provide a class for managing database connections and sessions
import logging
import os
import struct
import urllib
from contextlib import contextmanager
from typing import Any
import sqlalchemy
from sqlalchemy.engine.url import URL
from azure.identity import DefaultAzureCredential
from comma.database.model import Base, Distros, MonitoringSubjects
from comma.exceptions import CommaDatabaseError, CommaDataError
@ -28,42 +32,73 @@ class DatabaseDriver:
# Enable INFO-level logging when program is logging debug
# It's not ideal, because the messages are INFO level, but only enabled with debug
# As defined in msodbcsql.h
self.SQL_COPT_SS_ACCESS_TOKEN = 1256
if dry_run:
db_file = "comma.db"
LOGGER.info("Using local SQLite database at '%s'.", db_file)
engine = sqlalchemy.create_engine(f"sqlite:///{db_file}", echo=echo)
else:
LOGGER.info("Connecting to remote database...")
engine = sqlalchemy.create_engine(self._get_mssql_conn_str(), echo=echo)
engine = self.create_engine("ODBC Driver 17 for SQL Server")
Base.metadata.bind = engine
Base.metadata.create_all(engine)
self.session_factory = sqlalchemy.orm.sessionmaker(bind=engine)
@staticmethod
def _get_mssql_conn_str() -> sqlalchemy.engine.Engine:
"""
Create a connection string for MS SQL server and create engine instance
"""
def get_token(self) -> bytes:
try:
credential = DefaultAzureCredential()
cre_token = credential.get_token("https://database.windows.net/").token
token = cre_token.encode("utf-16-le")
token_struct = struct.pack(f"=I{len(token)}s", len(token), token)
return bytes(token_struct)
except Exception as e:
raise RuntimeError("Failed to obtain Azure AD token") from e
# Verify credentials are available
for envvar in ("COMMA_DB_URL", "COMMA_DB_NAME", "COMMA_DB_USERNAME", "COMMA_DB_PW"):
if not os.environ.get(envvar):
raise CommaDatabaseError(f"{envvar} is not defined in the current environment")
def create_engine(self, driver: str) -> Any:
if os.environ.get('COMMA_DB_USERNAME') and os.environ.get('COMMA_DB_PW'):
return self.create_engine_with_pass(driver)
else:
return self.create_engine_with_token(driver)
params = urllib.parse.quote_plus(
";".join(
(
"DRIVER={ODBC Driver 17 for SQL Server}",
f"SERVER={os.environ['COMMA_DB_URL']}",
f"DATABASE={os.environ['COMMA_DB_NAME']}",
f"UID={os.environ['COMMA_DB_USERNAME']}",
f"PWD={os.environ['COMMA_DB_PW']}",
)
)
def create_engine_with_pass(self, driver: str) -> Any:
try:
return sqlalchemy.create_engine(
URL(
drivername="mssql+pyodbc",
username=os.environ['COMMA_DB_USERNAME'],
password=os.environ['COMMA_DB_PW'],
host=os.environ['COMMA_DB_URL'],
database=os.environ['COMMA_DB_NAME'],
query={"driver": driver},
),
pool_recycle=300,
)
except Exception as e:
raise RuntimeError(
"Failed to create engine with username and password"
) from e
return f"mssql+pyodbc:///?odbc_connect={params}"
def create_engine_with_token(self, driver: str) -> Any:
try:
query = {
"odbc_connect": (
f"DRIVER={driver};DATABASE={os.environ['COMMA_DB_NAME']};"
f"SERVER={os.environ['COMMA_DB_URL']}"
)
}
connect_args = {
"attrs_before": {self.SQL_COPT_SS_ACCESS_TOKEN: self.get_token()}
}
return sqlalchemy.create_engine(
URL("mssql+pyodbc", query=query),
connect_args=connect_args,
pool_recycle=300,
)
except Exception as e:
raise RuntimeError("Failed to create engine with Azure AD token") from e
@contextmanager
def get_session(self) -> sqlalchemy.orm.session.Session:

Просмотреть файл

@ -22,7 +22,7 @@ dependencies = [
"openpyxl ~= 3.1.2",
"pydantic ~= 1.10.9",
"ruamel.yaml ~= 0.17.32",
"azure-identity ~= 1.17.1",
]
dynamic = ["version"]
license = {text = "MIT"}