Allow Azure AD login
This commit is contained in:
Родитель
7f834eb56e
Коммит
62524f927b
|
@ -7,10 +7,14 @@ Provide a class for managing database connections and sessions
|
|||
|
||||
import logging
|
||||
import os
|
||||
import struct
|
||||
import urllib
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy.engine.url import URL
|
||||
from azure.identity import DefaultAzureCredential
|
||||
|
||||
from comma.database.model import Base, Distros, MonitoringSubjects
|
||||
from comma.exceptions import CommaDatabaseError, CommaDataError
|
||||
|
@ -28,42 +32,73 @@ class DatabaseDriver:
|
|||
# Enable INFO-level logging when program is logging debug
|
||||
# It's not ideal, because the messages are INFO level, but only enabled with debug
|
||||
|
||||
# As defined in msodbcsql.h
|
||||
self.SQL_COPT_SS_ACCESS_TOKEN = 1256
|
||||
|
||||
if dry_run:
|
||||
db_file = "comma.db"
|
||||
LOGGER.info("Using local SQLite database at '%s'.", db_file)
|
||||
engine = sqlalchemy.create_engine(f"sqlite:///{db_file}", echo=echo)
|
||||
else:
|
||||
LOGGER.info("Connecting to remote database...")
|
||||
engine = sqlalchemy.create_engine(self._get_mssql_conn_str(), echo=echo)
|
||||
engine = self.create_engine("ODBC Driver 17 for SQL Server")
|
||||
|
||||
Base.metadata.bind = engine
|
||||
Base.metadata.create_all(engine)
|
||||
self.session_factory = sqlalchemy.orm.sessionmaker(bind=engine)
|
||||
|
||||
@staticmethod
|
||||
def _get_mssql_conn_str() -> sqlalchemy.engine.Engine:
|
||||
"""
|
||||
Create a connection string for MS SQL server and create engine instance
|
||||
"""
|
||||
def get_token(self) -> bytes:
|
||||
try:
|
||||
credential = DefaultAzureCredential()
|
||||
cre_token = credential.get_token("https://database.windows.net/").token
|
||||
token = cre_token.encode("utf-16-le")
|
||||
token_struct = struct.pack(f"=I{len(token)}s", len(token), token)
|
||||
return bytes(token_struct)
|
||||
except Exception as e:
|
||||
raise RuntimeError("Failed to obtain Azure AD token") from e
|
||||
|
||||
# Verify credentials are available
|
||||
for envvar in ("COMMA_DB_URL", "COMMA_DB_NAME", "COMMA_DB_USERNAME", "COMMA_DB_PW"):
|
||||
if not os.environ.get(envvar):
|
||||
raise CommaDatabaseError(f"{envvar} is not defined in the current environment")
|
||||
def create_engine(self, driver: str) -> Any:
|
||||
if os.environ.get('COMMA_DB_USERNAME') and os.environ.get('COMMA_DB_PW'):
|
||||
return self.create_engine_with_pass(driver)
|
||||
else:
|
||||
return self.create_engine_with_token(driver)
|
||||
|
||||
params = urllib.parse.quote_plus(
|
||||
";".join(
|
||||
(
|
||||
"DRIVER={ODBC Driver 17 for SQL Server}",
|
||||
f"SERVER={os.environ['COMMA_DB_URL']}",
|
||||
f"DATABASE={os.environ['COMMA_DB_NAME']}",
|
||||
f"UID={os.environ['COMMA_DB_USERNAME']}",
|
||||
f"PWD={os.environ['COMMA_DB_PW']}",
|
||||
)
|
||||
)
|
||||
def create_engine_with_pass(self, driver: str) -> Any:
|
||||
try:
|
||||
return sqlalchemy.create_engine(
|
||||
URL(
|
||||
drivername="mssql+pyodbc",
|
||||
username=os.environ['COMMA_DB_USERNAME'],
|
||||
password=os.environ['COMMA_DB_PW'],
|
||||
host=os.environ['COMMA_DB_URL'],
|
||||
database=os.environ['COMMA_DB_NAME'],
|
||||
query={"driver": driver},
|
||||
),
|
||||
pool_recycle=300,
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
"Failed to create engine with username and password"
|
||||
) from e
|
||||
|
||||
return f"mssql+pyodbc:///?odbc_connect={params}"
|
||||
def create_engine_with_token(self, driver: str) -> Any:
|
||||
try:
|
||||
query = {
|
||||
"odbc_connect": (
|
||||
f"DRIVER={driver};DATABASE={os.environ['COMMA_DB_NAME']};"
|
||||
f"SERVER={os.environ['COMMA_DB_URL']}"
|
||||
)
|
||||
}
|
||||
connect_args = {
|
||||
"attrs_before": {self.SQL_COPT_SS_ACCESS_TOKEN: self.get_token()}
|
||||
}
|
||||
return sqlalchemy.create_engine(
|
||||
URL("mssql+pyodbc", query=query),
|
||||
connect_args=connect_args,
|
||||
pool_recycle=300,
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError("Failed to create engine with Azure AD token") from e
|
||||
|
||||
@contextmanager
|
||||
def get_session(self) -> sqlalchemy.orm.session.Session:
|
||||
|
|
|
@ -22,7 +22,7 @@ dependencies = [
|
|||
"openpyxl ~= 3.1.2",
|
||||
"pydantic ~= 1.10.9",
|
||||
"ruamel.yaml ~= 0.17.32",
|
||||
|
||||
"azure-identity ~= 1.17.1",
|
||||
]
|
||||
dynamic = ["version"]
|
||||
license = {text = "MIT"}
|
||||
|
|
Загрузка…
Ссылка в новой задаче