Remove direct calls to config from database driver

This commit is contained in:
Avram Lubkin 2023-06-19 16:04:21 -04:00 коммит произвёл Avram Lubkin
Родитель 1edd645781
Коммит 17af657e66
2 изменённых файлов: 22 добавлений и 27 удалений

Просмотреть файл

@ -21,13 +21,13 @@ from comma.util.tracking import get_linux_repo
LOGGER = logging.getLogger("comma.cli")
def run(options):
def run(options, database):
"""
Handle run subcommand
"""
if options.dry_run:
with DatabaseDriver.get_session() as session:
with database.get_session() as session:
if session.query(Distros).first() is None:
session.add_all(config.default_distros)
if session.query(MonitoringSubjects).first() is None:
@ -39,12 +39,12 @@ def run(options):
if options.upstream:
LOGGER.info("Begin monitoring upstream")
Upstream(config, DatabaseDriver()).process_commits()
Upstream(config, database).process_commits()
LOGGER.info("Finishing monitoring upstream")
if options.downstream:
LOGGER.info("Begin monitoring downstream")
Downstream(config, DatabaseDriver()).monitor_downstream()
Downstream(config, database).monitor_downstream()
LOGGER.info("Finishing monitoring downstream")
@ -55,6 +55,7 @@ def main(args: Optional[Sequence[str]] = None):
options = parse_args(args)
# Configure logging
logging.basicConfig(
level={0: logging.WARNING, 1: logging.INFO}.get(options.verbose, logging.DEBUG),
format="%(asctime)s %(name)-5s %(levelname)-7s %(message)s",
@ -67,14 +68,16 @@ def main(args: Optional[Sequence[str]] = None):
# TODO(Issue 25: resolve configuration
database = DatabaseDriver(dry_run=options.dry_run, echo=options.verbose > 2)
if options.subcommand == "symbols":
missing = Symbols(config, DatabaseDriver()).get_missing_commits(options.file)
missing = Symbols(config, database).get_missing_commits(options.file)
print("Missing symbols from:")
for commit in missing:
print(f" {commit}")
if options.subcommand == "downstream":
downstream = Downstream(config, DatabaseDriver())
downstream = Downstream(config, database)
# Print current targets in database
if options.action in {"list", None}:
downstream.list_targets()
@ -84,10 +87,10 @@ def main(args: Optional[Sequence[str]] = None):
downstream.add_target(options.name, options.url, options.revision)
if options.subcommand == "run":
run(options)
run(options, database)
if options.subcommand == "spreadsheet":
spreadsheet = Spreadsheet(config, DatabaseDriver())
spreadsheet = Spreadsheet(config, database)
if args.export_commits:
spreadsheet.export_commits(args.in_file, args.out_file)
if args.update_commits:

Просмотреть файл

@ -14,7 +14,6 @@ from contextlib import contextmanager
import sqlalchemy
from comma.database.model import Base, MonitoringSubjects
from comma.util import config
LOGGER = logging.getLogger(__name__)
@ -25,24 +24,24 @@ class DatabaseDriver:
Database driver managing connections
"""
_instance = None
def __init__(self, dry_run, echo=False):
# Enable INFO-level logging when program is logging debug
# It's not ideal, because the messages are INFO level, but only enabled with debug
def __init__(self):
if config.dry_run:
if dry_run:
db_file = "comma.db"
LOGGER.info("Using local SQLite database at '%s'.", db_file)
engine = sqlalchemy.create_engine(f"sqlite:///{db_file}", echo=config.verbose > 2)
engine = sqlalchemy.create_engine(f"sqlite:///{db_file}", echo=echo)
else:
LOGGER.info("Connecting to remote database...")
engine = self._get_mssql_engine()
LOGGER.info("Connected!")
engine = sqlalchemy.create_engine(self._get_mssql_conn_str(), echo=echo)
Base.metadata.bind = engine
Base.metadata.create_all(engine)
self.session = sqlalchemy.orm.sessionmaker(bind=engine)
self.session_factory = sqlalchemy.orm.sessionmaker(bind=engine)
@staticmethod
def _get_mssql_engine() -> sqlalchemy.engine.Engine:
def _get_mssql_conn_str() -> sqlalchemy.engine.Engine:
"""
Create a connection string for MS SQL server and create engine instance
"""
@ -64,22 +63,15 @@ class DatabaseDriver:
)
)
return sqlalchemy.create_engine(
f"mssql+pyodbc:///?odbc_connect={params}",
echo=(config.verbose > 2),
)
return f"mssql+pyodbc:///?odbc_connect={params}"
@classmethod
@contextmanager
def get_session(cls) -> sqlalchemy.orm.session.Session:
def get_session(self) -> sqlalchemy.orm.session.Session:
"""
Context manager for getting a database session
"""
# Only support a single instance
if cls._instance is None:
cls._instance = cls()
session = cls._instance.session()
session = self.session_factory()
try:
yield session