add db setup/teardown to py.test suite

This commit is contained in:
mdoglio 2013-03-19 19:03:26 +00:00
Родитель 8151206a2f
Коммит eb11feeace
4 изменённых файлов: 216 добавлений и 7 удалений

1
runtests.sh Normal file → Executable file
Просмотреть файл

@ -1,3 +1,2 @@
#!/bin/sh
py.test tests/$* --cov-report html --cov treeherder

Просмотреть файл

@ -1,5 +1,8 @@
import os
from os.path import dirname
from django.core.management import call_command
import sys
import pytest
def pytest_sessionstart(session):
"""
@ -8,10 +11,12 @@ Set up the test environment.
Set DJANGO_SETTINGS_MODULE and sets up a test database.
"""
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "treeherder.settings.base")
sys.path.append(dirname(dirname(__file__)))
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "treeherder.settings")
from django.conf import settings
from django.test.simple import DjangoTestSuiteRunner
from treeherder.webapp.models import Datasource
# we don't actually let Django run the tests, but we need to use some
# methods of its runner for setup/teardown of dbs and some other things
session.django_runner = DjangoTestSuiteRunner()
@ -19,15 +24,21 @@ Set DJANGO_SETTINGS_MODULE and sets up a test database.
session.django_runner.setup_test_environment()
# support custom db prefix for tests for the main datazilla datasource
# as well as for the testproj and testpushlog dbs
DB_USER = "myuser"
DB_PASS = "mypass"
settings.DATABASES["default"]["USER"] = DB_USER
settings.DATABASES["default"]["PASSWORD"] = DB_PASS
prefix = getattr(settings, "TEST_DB_PREFIX", "")
settings.DATABASES["default"]["TEST_NAME"] = "{0}test_treeherder".format(prefix)
# this sets up a clean test-only database
session.django_db_config = session.django_runner.setup_databases()
# init the datasource db
call_command("init_master_db", interactive=False)
def pytest_sessionfinish(session):
"""Tear down the test environment, including databases."""
print("\n")
from treeherder.webapp.models import Datasource
session.django_runner.teardown_databases(session.django_db_config)
session.django_runner.teardown_test_environment()
@ -55,11 +66,16 @@ def pytest_runtest_teardown(item):
"""
Per-test teardown.
Roll back the Django ORM transaction
Roll back the Django ORM transaction and delete all the dbs created between tests
"""
from django.test.testcases import restore_transaction_methods
from django.db import transaction
from treeherder.webapp.models import Datasource
ds_list = Datasource.objects.all()
for ds in ds_list:
ds.delete()
restore_transaction_methods()
transaction.rollback()

44
tests/test_setup.py Normal file
Просмотреть файл

@ -0,0 +1,44 @@
import pytest
from django.conf import settings
from treeherder.webapp.models import Datasource
import MySQLdb
@pytest.fixture
def jobs_ds():
prefix = getattr(settings, "TEST_DB_PREFIX", "")
return Datasource.objects.create(
project="{0}test_myproject".format(prefix),
dataset=1,
contenttype="jobs",
host="localhost",
)
@pytest.fixture
def objectstore_ds():
prefix = getattr(settings, "TEST_DB_PREFIX", "")
return Datasource.objects.create(
project="{0}test_myproject".format(prefix),
dataset=1,
contenttype="objectstore",
host="localhost",
)
@pytest.fixture
def db_conn():
return MySQLdb.connect(
host="localhost",
user=settings.DATABASES['default']['USER'],
passwd=settings.DATABASES['default']['PASSWORD'],
)
def test_datasource_db_created(jobs_ds, db_conn):
cur = db_conn.cursor()
cur.execute("SHOW DATABASES;")
rows = cur.fetchall()
assert jobs_ds.name in [r[0] for r in rows], \
"When a datasource is created, a new db should be created too"
db_conn.close()

Просмотреть файл

@ -1,5 +1,8 @@
from __future__ import unicode_literals
from django.db import models
import uuid
import subprocess
from treeherder import path
class Product(models.Model):
@ -149,7 +152,7 @@ class Datasource(models.Model):
type = models.CharField(max_length=25L)
oauth_consumer_key = models.CharField(max_length=45L, blank=True)
oauth_consumer_secret = models.CharField(max_length=45L, blank=True)
creation_date = models.DateTimeField()
creation_date = models.DateTimeField(auto_now_add=True)
cron_batch = models.CharField(max_length=45L, blank=True)
class Meta:
@ -163,6 +166,153 @@ class Datasource(models.Model):
return "{0} ({1})".format(
self.name, self.project)
def save(self, *args, **kwargs):
inserting = not self.pk
if inserting:
if not self.name:
self.name = "{0}_{1}_{2}".format(
self.project,
self.contenttype,
self.dataset
)
if not self.type:
self.type = "mysql"
self.oauth_consumer_key = None
self.oauth_consumer_secret = None
if self.contenttype == 'objectstore':
self.oauth_consumer_key = uuid.uuid4()
self.oauth_consumer_secret = uuid.uuid4()
super(Datasource, self).save(*args, **kwargs)
if inserting:
self.create_db()
def create_db(self, schema_file=None):
"""
Create the database for this source, using given SQL schema file.
If schema file is not given, defaults to
"template_schema/schema_<contenttype>.sql.tmpl".
Assumes that the database server at ``self.host`` is accessible, and
that ``DATABASE_USER`` (identified by
``DATABASE_PASSWORD`` exists on it and has permissions to
create databases.
"""
from django.conf import settings
import MySQLdb
DB_USER = settings.DATABASES["default"]["USER"]
DB_PASS = settings.DATABASES["default"]["PASSWORD"]
if self.type.lower().startswith("mysql-"):
engine = self.type[len("mysql-"):]
elif self.type.lower() == "mysql":
engine = "InnoDB"
else:
raise NotImplementedError(
"Currently only MySQL data source is supported.")
if schema_file is None:
schema_file = path(
"model",
"sql",
"template_schema",
"project_{0}_1.sql.tmpl".format(self.contenttype),
)
conn = MySQLdb.connect(
host=self.host,
user=DB_USER,
passwd=DB_PASS,
)
cur = conn.cursor()
cur.execute("CREATE DATABASE {0}".format(self.name))
conn.close()
# MySQLdb provides no way to execute an entire SQL file in bulk, so we
# have to shell out to the commandline client.
with open(schema_file) as f:
# set the engine to use
sql = f.read().format(engine=engine)
args = [
"mysql",
"--host={0}".format(self.host),
"--user={0}".format(DB_USER),
]
if DB_PASS:
args.append(
"--password={0}".format(
DB_PASS)
)
args.append(self.name)
proc = subprocess.Popen(
args,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
(output, _) = proc.communicate(sql)
if proc.returncode:
raise IOError(
"Unable to set up schema for datasource {0}: "
"mysql returned code {1}, output follows:\n\n{2}".format(
self.key, proc.returncode, output
)
)
def delete_db(self):
from django.conf import settings
import MySQLdb
DB_USER = settings.DATABASES["default"]["USER"]
DB_PASS = settings.DATABASES["default"]["PASSWORD"]
conn = MySQLdb.connect(
host=self.host,
user=DB_USER,
passwd=DB_PASS,
)
cur = conn.cursor()
cur.execute("DROP DATABASE {0}".format(self.name))
conn.close()
def delete(self, *args, **kwargs):
self.delete_db()
super(Datasource, self).delete(*args, **kwargs)
def truncate(self, skip_list=None):
"""
Truncate all tables in the db self refers to.
Skip_list is a list of table names to skip truncation.
"""
from django.conf import settings
import MySQLdb
skip_list = set(skip_list or [])
DB_USER = settings.DATABASES["default"]["USER"]
DB_PASS = settings.DATABASES["default"]["PASSWORD"]
conn = MySQLdb.connect(
host=self.host,
user=DB_USER,
passwd=DB_PASS,
db=self.name,
)
cur = conn.cursor()
cur.execute("SET FOREIGN_KEY_CHECKS = 0")
cur.execute("SHOW TABLES")
for table, in cur.fetchall():
# if there is a skip_list, then skip any table with matching name
if table.lower() not in skip_list:
# needed to use backticks around table name, because if the
# table name is a keyword (like "option") then this will fail
cur.execute("TRUNCATE TABLE `{0}`".format(table))
cur.execute("SET FOREIGN_KEY_CHECKS = 1")
conn.close()
class JobGroup(models.Model):
id = models.IntegerField(primary_key=True)