Adding initial unit tests for PGSMO (#49)

* Adding unit tests for querying helper
Renaming ConnectionWrapper to ServerConnection (a la SMO)
Making execute methods be methods on the ServerConnection

* Adding tests for templating, except pgAdmin code

* Adding tests for ScanKeywordExtraLookup

* Some more tests...

* Fixing more merge issues

* Fixing nonsense with conflicting package names

* Adding unittests for column class

* Adding unittests for database objects

* Adding unittests for role objects

* Adding unittests for schema objects

* Refactoring out init test which will be reused

* Refactoring out init tests to make it easier to reuse the validation code

* Adding tests for table class

* Adding tests for tablespace class

* Adding tests for view class

* Adding tests for server class

* Flake8 stuff and things

* Fixing bug from merge...

* Changes as per PR comments
This commit is contained in:
Benjamin Russell 2017-07-06 16:02:32 -07:00 коммит произвёл GitHub
Родитель c63502f826
Коммит 3cee508796
26 изменённых файлов: 1256 добавлений и 238 удалений

Просмотреть файл

@ -6,18 +6,19 @@
from typing import List, Optional
import pgsmo.objects.node_object as node
import pgsmo.utils as utils
import pgsmo.utils.querying as querying
import pgsmo.utils.templating as templating
TEMPLATE_ROOT = utils.templating.get_template_root(__file__, 'templates')
TEMPLATE_ROOT = templating.get_template_root(__file__, 'templates')
class Column(node.NodeObject):
@classmethod
def get_nodes_for_parent(cls, conn: utils.querying.ConnectionWrapper, tid: int) -> List['Column']:
def get_nodes_for_parent(cls, conn: querying.ServerConnection, tid: int) -> List['Column']:
return node.get_nodes(conn, TEMPLATE_ROOT, cls._from_node_query, tid=tid)
@classmethod
def _from_node_query(cls, conn: utils.querying.ConnectionWrapper, **kwargs) -> 'Column':
def _from_node_query(cls, conn: querying.ServerConnection, **kwargs) -> 'Column':
"""
Creates a new Column object based on the the results from the column nodes query
:param conn: Connection used to execute the column nodes query
@ -37,7 +38,7 @@ class Column(node.NodeObject):
return col
def __init__(self, conn: utils.querying.ConnectionWrapper, name: str, datatype: str):
def __init__(self, conn: querying.ServerConnection, name: str, datatype: str):
"""
Initializes a new instance of a Column
:param conn: Connection to the server/database that this object will belong to
@ -63,11 +64,3 @@ class Column(node.NodeObject):
@property
def not_null(self) -> Optional[bool]:
return self._not_null
# METHODS ##############################################################
def refresh(self):
self._fetch_properties()
# IMPLEMENTATION DETAILS ###############################################
def _fetch_properties(self):
pass

Просмотреть файл

@ -7,18 +7,19 @@ from typing import List, Optional # noqa
import pgsmo.objects.node_object as node
from pgsmo.objects.schema.schema import Schema
import pgsmo.utils as utils
import pgsmo.utils.querying as querying
import pgsmo.utils.templating as templating
TEMPLATE_ROOT = utils.templating.get_template_root(__file__, 'templates')
TEMPLATE_ROOT = templating.get_template_root(__file__, 'templates')
class Database(node.NodeObject):
@classmethod
def get_nodes_for_parent(cls, conn: utils.querying.ConnectionWrapper) -> List['Database']:
def get_nodes_for_parent(cls, conn: querying.ServerConnection) -> List['Database']:
return node.get_nodes(conn, TEMPLATE_ROOT, cls._from_node_query, last_system_oid=0)
@classmethod
def _from_node_query(cls, conn: utils.querying.ConnectionWrapper, **kwargs) -> 'Database':
def _from_node_query(cls, conn: querying.ServerConnection, **kwargs) -> 'Database':
"""
Creates a new Database object based on the results from a query to lookup databases
:param conn: Connection used to generate the db info query
@ -34,7 +35,6 @@ class Database(node.NodeObject):
"""
db = cls(conn, kwargs['name'])
db._oid = kwargs['did']
db._is_connected = kwargs['name'] == conn.dsn_parameters.get('dbname')
db._tablespace = kwargs['spcname']
db._allow_conn = kwargs['datallowconn']
db._can_create = kwargs['cancreate']
@ -42,13 +42,13 @@ class Database(node.NodeObject):
return db
def __init__(self, conn: utils.querying.ConnectionWrapper, name: str):
def __init__(self, conn: querying.ServerConnection, name: str):
"""
Initializes a new instance of a database
:param name: Name of the database
"""
super(Database, self).__init__(conn, name)
self._is_connected: bool = False
self._is_connected: bool = conn.dsn_parameters.get('dbname') == name
# Declare the optional parameters
self._tablespace: Optional[str] = None
@ -57,7 +57,9 @@ class Database(node.NodeObject):
self._owner_oid: Optional[int] = None
# Declare the child items
self._schemas: node.NodeCollection = node.NodeCollection(lambda: Schema.get_nodes_for_parent(self._conn))
self._schemas: Optional[node.NodeCollection] = None if not self._is_connected else node.NodeCollection(
lambda: Schema.get_nodes_for_parent(conn)
)
# PROPERTIES ###########################################################
# TODO: Create setters for optional values
@ -80,11 +82,6 @@ class Database(node.NodeObject):
return self._schemas
# METHODS ##############################################################
def refresh(self):
self._fetch_properties()
self._schemas.reset()
def create(self):
pass
@ -94,6 +91,7 @@ class Database(node.NodeObject):
def delete(self):
pass
# IMPLEMENTATION DETAILS ###############################################
def _fetch_properties(self):
pass
def refresh(self):
"""Resets the internal collection of child objects"""
if self._schemas is not None:
self._schemas.reset()

Просмотреть файл

@ -14,12 +14,12 @@ import pgsmo.utils.querying as querying
class NodeObject:
@classmethod
@abstractmethod
def _from_node_query(cls, conn: querying.ConnectionWrapper, **kwargs):
def _from_node_query(cls, conn: querying.ServerConnection, **kwargs):
pass
def __init__(self, conn: querying.ConnectionWrapper, name: str):
def __init__(self, conn: querying.ServerConnection, name: str):
# Define the state of the object
self._conn: querying.ConnectionWrapper = conn
self._conn: querying.ServerConnection = conn
# Declare node basic properties
self._name: str = name
@ -92,9 +92,9 @@ class NodeCollection:
T = TypeVar('T')
def get_nodes(conn: querying.ConnectionWrapper,
def get_nodes(conn: querying.ServerConnection,
template_root: str,
generator: Callable[[type, querying.ConnectionWrapper, Dict[str, any]], T],
generator: Callable[[type, querying.ServerConnection, Dict[str, any]], T],
**kwargs) -> List[T]:
"""
Renders and executes nodes.sql for the given database version to generate a list of NodeObjects
@ -108,6 +108,6 @@ def get_nodes(conn: querying.ConnectionWrapper,
templating.get_template_path(template_root, 'nodes.sql', conn.version),
**kwargs
)
cols, rows = querying.execute_dict(conn, sql)
cols, rows = conn.execute_dict(sql)
return [generator(conn, **row) for row in rows]

Просмотреть файл

@ -6,15 +6,16 @@
from typing import List, Optional
from pgsmo.objects.node_object import NodeObject, get_nodes
import pgsmo.utils as utils
import pgsmo.utils.querying as querying
import pgsmo.utils.templating as templating
TEMPLATE_ROOT = utils.templating.get_template_root(__file__, 'templates')
TEMPLATE_ROOT = templating.get_template_root(__file__, 'templates')
class Role(NodeObject):
@classmethod
def get_nodes_for_parent(cls, conn: utils.querying.ConnectionWrapper) -> List['Role']:
def get_nodes_for_parent(cls, conn: querying.ServerConnection) -> List['Role']:
"""
Generates a list of roles for a given server. Intended to only be called by a Server object
:param conn: Connection to use to look up the roles for the server
@ -23,7 +24,7 @@ class Role(NodeObject):
return get_nodes(conn, TEMPLATE_ROOT, cls._from_node_query)
@classmethod
def _from_node_query(cls, conn: utils.querying.ConnectionWrapper, **kwargs) -> 'Role':
def _from_node_query(cls, conn: querying.ServerConnection, **kwargs) -> 'Role':
"""
Creates a Role object from the result of a role node query
:param conn: Connection that executed the role node query
@ -39,7 +40,7 @@ class Role(NodeObject):
return role
def __init__(self, conn: utils.querying.ConnectionWrapper, name: str):
def __init__(self, conn: querying.ServerConnection, name: str):
"""
Initializes internal state of a Role object
:param conn: Connection that executed the role node query

Просмотреть файл

@ -9,20 +9,21 @@ from typing import List, Optional
import pgsmo.objects.node_object as node
from pgsmo.objects.table.table import Table
from pgsmo.objects.view.view import View
import pgsmo.utils as utils
import pgsmo.utils.querying as querying
import pgsmo.utils.templating as templating
TEMPLATE_ROOT = utils.templating.get_template_root(__file__, 'templates')
TEMPLATE_ROOT = templating.get_template_root(__file__, 'templates')
class Schema(node.NodeObject):
@classmethod
def get_nodes_for_parent(cls, conn: utils.querying.ConnectionWrapper) -> List['Schema']:
def get_nodes_for_parent(cls, conn: querying.ServerConnection) -> List['Schema']:
type_template_root = path.join(TEMPLATE_ROOT, conn.server_type)
return node.get_nodes(conn, type_template_root, cls._from_node_query)
@classmethod
def _from_node_query(cls, conn: utils.querying.ConnectionWrapper, **kwargs) -> 'Schema':
def _from_node_query(cls, conn: querying.ServerConnection, **kwargs) -> 'Schema':
"""
Creates an instance of a schema object from the results of a nodes query
:param conn: The connection used to execute the nodes query
@ -41,7 +42,7 @@ class Schema(node.NodeObject):
return schema
def __init__(self, conn: utils.querying.ConnectionWrapper, name: str):
def __init__(self, conn: querying.ServerConnection, name: str):
super(Schema, self).__init__(conn, name)
# Declare the optional parameters
@ -53,7 +54,7 @@ class Schema(node.NodeObject):
lambda: Table.get_nodes_for_parent(self._conn, self._oid)
)
self._views: node.NodeCollection = node.NodeCollection(
lambda: View.get_nodes_for_parent(self._conn, self.oid)
lambda: View.get_nodes_for_parent(self._conn, self._oid)
)
# PROPERTIES ###########################################################
@ -76,5 +77,6 @@ class Schema(node.NodeObject):
# METHODS ##############################################################
def refresh(self) -> None:
self._tables = Table.get_tables_for_schema(self._conn, self._oid)
self._views = View.get_views_for_schema(self._conn, self._oid)
"""Resets the internal collections of child objects"""
self._tables.reset()
self._views.reset()

Просмотреть файл

@ -25,10 +25,10 @@ class Server:
:param conn: psycopg2 connection
"""
# Everything we know about the server will be based on the connection
self._conn = utils.querying.ConnectionWrapper(conn)
self._conn = utils.querying.ServerConnection(conn)
# Declare the server properties
props = self._conn.connection.get_dsn_parameters()
props = self._conn.dsn_parameters
self._host: str = props['host']
self._port: int = int(props['port'])
self._maintenance_db: str = props['dbname']
@ -55,8 +55,8 @@ class Server:
return self._host
@property
def in_recovery(self) -> bool:
"""Whether or not the server is in recovery mode"""
def in_recovery(self) -> Optional[bool]:
"""Whether or not the server is in recovery mode. If None, value was not loaded from server"""
return self._in_recovery
@property
@ -75,8 +75,8 @@ class Server:
return self._conn.version
@property
def wal_paused(self) -> bool:
"""Whether or not the Write-Ahead Log (WAL) is paused"""
def wal_paused(self) -> Optional[bool]:
"""Whether or not the Write-Ahead Log (WAL) is paused. If None, value was not loaded from server"""
return self._wal_paused
# -CHILD OBJECTS #######################################################
@ -98,15 +98,17 @@ class Server:
# METHODS ##############################################################
# IMPLEMENTATION DETAILS ###############################################
def _fetch_recovery_state(self) -> None:
recovery_check_sql = utils.templating.render_template(
utils.templating.get_template_path(TEMPLATE_ROOT, 'check_recovery.sql', self._conn.version)
)
cols, rows = utils.querying.execute_dict(self._conn, recovery_check_sql)
if len(rows) > 0:
self._in_recovery = rows[0]['inrecovery']
self._wal_paused = rows[0]['isreplaypaused']
else:
self._in_recovery = None
self._wal_paused = None
# Commenting out until support for extended properties is added
# See https://github.com/Microsoft/carbon/issues/1342
# def _fetch_recovery_state(self) -> None:
# recovery_check_sql = utils.templating.render_template(
# utils.templating.get_template_path(TEMPLATE_ROOT, 'check_recovery.sql', self._conn.version)
# )
#
# cols, rows = self._conn.execute_dict(recovery_check_sql)
# if len(rows) > 0:
# self._in_recovery = rows[0]['inrecovery']
# self._wal_paused = rows[0]['isreplaypaused']
# else:
# self._in_recovery = None
# self._wal_paused = None

Просмотреть файл

@ -7,19 +7,20 @@ from typing import List
from pgsmo.objects.column.column import Column
import pgsmo.objects.node_object as node
import pgsmo.utils as utils
import pgsmo.utils.querying as querying
import pgsmo.utils.templating as templating
TEMPLATE_ROOT = utils.templating.get_template_root(__file__, 'templates')
TEMPLATE_ROOT = templating.get_template_root(__file__, 'templates')
class Table(node.NodeObject):
@classmethod
def get_nodes_for_parent(cls, conn: utils.querying.ConnectionWrapper, schema_id: int) -> List['Table']:
def get_nodes_for_parent(cls, conn: querying.ServerConnection, schema_id: int) -> List['Table']:
return node.get_nodes(conn, TEMPLATE_ROOT, cls._from_node_query, scid=schema_id)
@classmethod
def _from_node_query(cls, conn: utils.querying.ConnectionWrapper, **kwargs) -> 'Table':
def _from_node_query(cls, conn: querying.ServerConnection, **kwargs) -> 'Table':
"""
Creates a table instance from the results of a node query
:param conn: The connection used to execute the node query
@ -34,7 +35,7 @@ class Table(node.NodeObject):
return table
def __init__(self, conn: utils.querying.ConnectionWrapper, name: str):
def __init__(self, conn: querying.ServerConnection, name: str):
super(Table, self).__init__(conn, name)
# Declare child items

Просмотреть файл

@ -6,14 +6,15 @@
from typing import List, Optional
from pgsmo.objects.node_object import NodeObject, get_nodes
import pgsmo.utils as utils
import pgsmo.utils.querying as querying
import pgsmo.utils.templating as templating
TEMPLATE_ROOT = utils.templating.get_template_root(__file__, 'templates')
TEMPLATE_ROOT = templating.get_template_root(__file__, 'templates')
class Tablespace(NodeObject):
@classmethod
def get_nodes_for_parent(cls, conn: utils.querying.ConnectionWrapper) -> List['Tablespace']:
def get_nodes_for_parent(cls, conn: querying.ServerConnection) -> List['Tablespace']:
"""
Creates a list of tablespaces that belong to the server. Intended to be called by Server class
:param conn: Connection to a server to use to lookup the information
@ -22,7 +23,7 @@ class Tablespace(NodeObject):
return get_nodes(conn, TEMPLATE_ROOT, cls._from_node_query)
@classmethod
def _from_node_query(cls, conn: utils.querying.ConnectionWrapper, **kwargs) -> 'Tablespace':
def _from_node_query(cls, conn: querying.ServerConnection, **kwargs) -> 'Tablespace':
"""
Creates a tablespace from a row of a nodes query result
:param conn: Connection to a server to use to lookup the information
@ -36,7 +37,7 @@ class Tablespace(NodeObject):
return tablespace
def __init__(self, conn: utils.querying.ConnectionWrapper, name: str):
def __init__(self, conn: querying.ServerConnection, name: str):
"""
Initializes internal state of a Role object
:param conn: Connection that executed the role node query

Просмотреть файл

@ -8,19 +8,20 @@ from typing import List
from pgsmo.objects.column.column import Column
import pgsmo.objects.node_object as node
import pgsmo.utils as utils
import pgsmo.utils.querying as querying
import pgsmo.utils.templating as templating
TEMPLATE_ROOT = utils.templating.get_template_root(__file__, 'view_templates')
TEMPLATE_ROOT = templating.get_template_root(__file__, 'view_templates')
class View(node.NodeObject):
@classmethod
def get_nodes_for_parent(cls, conn: utils.querying.ConnectionWrapper, scid: int) -> List['View']:
def get_nodes_for_parent(cls, conn: querying.ServerConnection, scid: int) -> List['View']:
type_template_root = path.join(TEMPLATE_ROOT, conn.server_type)
return node.get_nodes(conn, type_template_root, cls._from_node_query, scid=scid)
@classmethod
def _from_node_query(cls, conn: utils.querying.ConnectionWrapper, **kwargs) -> 'View':
def _from_node_query(cls, conn: querying.ServerConnection, **kwargs) -> 'View':
"""
Creates a view object from the results of a node query
:param conn: Connection used to execute the nodes query
@ -35,7 +36,7 @@ class View(node.NodeObject):
return view
def __init__(self, conn: utils.querying.ConnectionWrapper, name: str):
def __init__(self, conn: querying.ServerConnection, name: str):
super(View, self).__init__(conn, name)
# Declare child items

Просмотреть файл

@ -8,8 +8,14 @@ from typing import List, Mapping, Tuple
from psycopg2.extensions import Column, connection, cursor # noqa
class ConnectionWrapper:
class ServerConnection:
"""Wrapper for a psycopg2 connection that makes various properties easier to access"""
def __init__(self, conn: connection):
"""
Creates a new connection wrapper. Parses version string
:param conn: PsycoPG2 connection object
"""
self._conn = conn
self._dsn_parameters = conn.get_dsn_parameters()
@ -24,40 +30,35 @@ class ConnectionWrapper:
# PROPERTIES ###########################################################
@property
def connection(self) -> connection:
"""The psycopg2 connection that this object wraps"""
return self._conn
@property
def dsn_parameters(self) -> Mapping[str, str]:
"""DSN properties of the underlying connection"""
return self._dsn_parameters
@property
def server_type(self) -> str:
"""Server type for distinguishing between standard PG and PG supersets"""
return 'pg' # TODO: Determine if a server is PPAS or PG
@property
def version(self) -> Tuple[int, int, int]:
"""Tuple that splits version string into sensible values"""
return self._version
def execute_2d_array(conn: ConnectionWrapper, query: str, params=None) -> Tuple[List[Column], List[list]]:
cur: cursor = conn.connection.cursor()
try:
cur.execute(query, params)
cols: List[Column] = cur.description
rows: List[list] = []
if cur.rowcount > 0:
for row in cur:
rows.append(row)
return cols, rows
finally:
cur.close()
def execute_dict(conn: ConnectionWrapper, query: str, params=None) -> Tuple[List[Column], List[dict]]:
cur: cursor = conn.connection.cursor()
# METHODS ##############################################################
def execute_dict(self, query: str, params=None) -> Tuple[List[Column], List[dict]]:
"""
Executes a query and returns the results as an ordered list of dictionaries that map column
name to value. Columns are returned, as well.
:param conn: The connection to use to execute the query
:param query: The text of the query to execute
:param params: Optional parameters to inject into the query
:return: A list of column objects and a list of rows, which are formatted as dicts.
"""
cur: cursor = self._conn.cursor()
try:
cur.execute(query, params)

Просмотреть файл

@ -12,7 +12,7 @@ from psycopg2.extensions import adapt
TEMPLATE_ENVIRONMENTS: Dict[str, Environment] = {}
TEMPLATE_FOLDER_REGEX = re.compile('(\d+)\.(\d+)(?:_(\w+))?$')
TEMPLATE_NON_DISCOVERED_NAMES: List[str] = ['macros']
TEMPLATE_SKIPPED_FOLDERS: List[str] = ['macros']
def get_template_root(file_path: str, template_directory: str) -> str:
@ -29,7 +29,7 @@ def get_template_path(template_root: str, template_name: str, server_version: Tu
"""
# Step 1) Get the list of folders in the template folder that contains the
# Step 1.1) Get the list of folders in the template root folder
all_folders: List[str] = [x[0] for x in os.walk(template_root)]
all_folders: List[str] = [os.path.normpath(x[0]) for x in os.walk(template_root)]
# Step 1.2) Filter out the folders that don't contain the target template
containing_folders: List[str] = [x for x in all_folders if template_name in next(os.walk(x))[2]]
@ -40,16 +40,17 @@ def get_template_path(template_root: str, template_name: str, server_version: Tu
# Step 2) Iterate over the list of directories and check if the server version fits the bill
for folder in containing_folders:
# Skip over non-included folders
if os.path.dirname(folder) in TEMPLATE_NON_DISCOVERED_NAMES:
if os.path.basename(folder) in TEMPLATE_SKIPPED_FOLDERS:
continue
# If we are at the default, we are at the end of the list, so it is the only valid match
if folder.endswith('/+default'):
if folder.endswith(os.sep + '+default'):
return os.path.join(folder, template_name)
# Process the folder name with the regex
match = TEMPLATE_FOLDER_REGEX.search(folder)
if not match:
# This indicates a serious bug that shouldn't occur in production code, so this needn't be localized.
raise ValueError(f'Templates folder {template_root} contains improperly formatted folder name {folder}')
captures = match.groups()
major = int(captures[0])
@ -67,6 +68,7 @@ def get_template_path(template_root: str, template_name: str, server_version: Tu
# TODO: Modifier is minus
# If we make it to here, the template doesn't exist.
# This indicates a serious bug that shouldn't occur in production code, so this doesn't need to be localized.
raise ValueError(f'Template folder {template_root} does not contain {template_name}')
@ -75,6 +77,7 @@ def render_template(template_path: str, **context) -> str:
Renders a template from the template folder with the given context.
:param template_path: the path to the template to be rendered
:param context: the variables that should be available in the context of the template.
:return: The template rendered with the provided context
"""
path, filename = os.path.split(template_path)
if path not in TEMPLATE_ENVIRONMENTS:
@ -85,13 +88,13 @@ def render_template(template_path: str, **context) -> str:
# Create the environment and add the basic filters
new_env: Environment = Environment(loader=loader)
new_env.filters['qtLiteral'] = qtLiteral
new_env.filters['qtIdent'] = qtIdent
new_env.filters['qtTypeIdent'] = qtTypeIdent
new_env.filters['qtLiteral'] = qt_literal
new_env.filters['qtIdent'] = qt_ident
new_env.filters['qtTypeIdent'] = qt_type_ident
TEMPLATE_ENVIRONMENTS[template_path] = new_env
TEMPLATE_ENVIRONMENTS[path] = new_env
env = TEMPLATE_ENVIRONMENTS[template_path]
env = TEMPLATE_ENVIRONMENTS[path]
return env.get_template(filename).render(context)
@ -101,6 +104,7 @@ def render_template_string(source, **context):
autoescaped.
:param source: the source code of the template to be rendered
:param context: the variables that should be available in the context of the template.
:return: The template rendered with the provided context
"""
template = Template(source)
return template.render(context)
@ -115,7 +119,7 @@ def render_template_string(source, **context):
##########################################################################
def qtLiteral(value):
def qt_literal(value):
adapted = adapt(value)
# Not all adapted objects have encoding
@ -133,7 +137,7 @@ def qtLiteral(value):
return res
def qtTypeIdent(conn, *args):
def qt_type_ident(conn, *args):
# We're not using the conn object at the moment, but - we will modify the
# logic to use the server version specific keywords later.
res = None
@ -144,7 +148,7 @@ def qtTypeIdent(conn, *args):
continue
value = val
if needsQuoting(val, True):
if needs_quoting(val, True):
value = value.replace("\"", "\"\"")
value = "\"" + value + "\""
@ -153,7 +157,7 @@ def qtTypeIdent(conn, *args):
return res
def qtIdent(conn, *args):
def qt_ident(conn, *args):
# We're not using the conn object at the moment, but - we will modify the
# logic to use the server version specific keywords later.
res = None
@ -161,14 +165,14 @@ def qtIdent(conn, *args):
for val in args:
if type(val) == list:
return map(lambda w: qtIdent(conn, w), val)
return map(lambda w: qt_ident(conn, w), val)
if len(val) == 0:
continue
value = val
if needsQuoting(val, False):
if needs_quoting(val, False):
value = value.replace("\"", "\"\"")
value = "\"" + value + "\""
@ -177,18 +181,18 @@ def qtIdent(conn, *args):
return res
def needsQuoting(key, forTypes):
def needs_quoting(key, for_types):
value = key
valNoArray = value
val_no_array = value
# check if the string is number or not
if isinstance(value, int):
return True
# certain types should not be quoted even though it contains a space. Evilness.
elif forTypes and value[-2:] == u"[]":
valNoArray = value[:-2]
elif for_types and value[-2:] == u"[]":
val_no_array = value[:-2]
if forTypes and valNoArray.lower() in [
if for_types and val_no_array.lower() in [
u'bit varying',
u'"char"',
u'character varying',
@ -203,21 +207,21 @@ def needsQuoting(key, forTypes):
return False
# If already quoted?, If yes then do not quote again
if forTypes and valNoArray:
if valNoArray.startswith('"') \
or valNoArray.endswith('"'):
if for_types and val_no_array:
if val_no_array.startswith('"') \
or val_no_array.endswith('"'):
return False
if u'0' <= valNoArray[0] <= u'9':
if u'0' <= val_no_array[0] <= u'9':
return True
for c in valNoArray:
for c in val_no_array:
if (not (u'a' <= c <= u'z') and c != u'_' and
not (u'0' <= c <= u'9')):
return True
# check string is keywaord or not
category = ScanKeywordExtraLookup(value)
category = scan_keyword_extra_lookup(value)
if category is None:
return False
@ -227,18 +231,21 @@ def needsQuoting(key, forTypes):
return False
# COL_NAME_KEYWORD
if forTypes and category == 1:
if for_types and category == 1:
return False
return True
def ScanKeywordExtraLookup(key):
# UNRESERVED_KEYWORD 0
# COL_NAME_KEYWORD 1
# TYPE_FUNC_NAME_KEYWORD 2
# RESERVED_KEYWORD 3
extraKeywords = {
def scan_keyword_extra_lookup(key):
return (key in _EXTRA_KEYWORDS and _EXTRA_KEYWORDS[key]) or _KEYWORD_DICT.get(key)
# UNRESERVED_KEYWORD 0
# COL_NAME_KEYWORD 1
# TYPE_FUNC_NAME_KEYWORD 2
# RESERVED_KEYWORD 3
_EXTRA_KEYWORDS = {
'connect': 3,
'convert': 3,
'distributed': 0,
@ -260,14 +267,10 @@ def ScanKeywordExtraLookup(key):
'tinyint': 3,
'tinytext': 3,
'varchar2': 3
}
}
return (key in extraKeywords and extraKeywords[key]) or ScanKeyword(key)
def ScanKeyword(key):
# ScanKeyword function for PostgreSQL 9.5rc1
keywordDict = {
# ScanKeyword function for PostgreSQL 9.5rc1
_KEYWORD_DICT = {
"abort": 0, "absolute": 0, "access": 0, "action": 0, "add": 0, "admin": 0, "after": 0, "aggregate": 0, "all": 3,
"also": 0, "alter": 0, "always": 0, "analyze": 3, "and": 3, "any": 3, "array": 3, "as": 3, "asc": 3,
"assertion": 0, "assignment": 0, "asymmetric": 3, "at": 0, "attribute": 0, "authorization": 2, "backward": 0,
@ -321,5 +324,4 @@ def ScanKeyword(key):
"with": 3, "within": 0, "without": 0, "work": 0, "wrapper": 0, "write": 0, "xml": 0, "xmlattributes": 1,
"xmlconcat": 1, "xmlelement": 1, "xmlexists": 1, "xmlforest": 1, "xmlparse": 1, "xmlpi": 1, "xmlroot": 1,
"xmlserialize": 1, "year": 0, "yes": 0, "zone": 0
}
return keywordDict.get(key)
}

Просмотреть файл

Просмотреть файл

@ -0,0 +1 @@
{{foo}}

Просмотреть файл

@ -125,23 +125,22 @@ class TestNodeObject(unittest.TestCase):
def test_get_nodes(self):
# Setup:
# ... Create a mockup of a connection wrapper
# ... Create a mockup of a server connection with a mock executor
version = (1, 1, 1)
class MockConn:
def __init__(self):
self.version = version
mock_objs = [{'name': 'abc', 'oid': 123}, {'name': 'def', 'oid': 456}]
mock_executor = mock.MagicMock(return_value=([{}, {}], mock_objs))
mock_conn = MockConn()
mock_conn.execute_dict = mock_executor
# ... Create a mock template renderer
mock_render = mock.MagicMock(return_value="SQL")
mock_template_path = mock.MagicMock(return_value="path")
# ... Create a mock query executor
mock_objs = [{'name': 'abc', 'oid': 123}, {'name': 'def', 'oid': 456}]
mock_executor = mock.MagicMock(return_value=([{}, {}], mock_objs))
# ... Create a mock generator
mock_output = {}
mock_generator = mock.MagicMock(return_value=mock_output)
@ -149,7 +148,6 @@ class TestNodeObject(unittest.TestCase):
# ... Do the patching
with mock.patch('pgsmo.objects.node_object.templating.render_template', mock_render, create=True):
with mock.patch('pgsmo.objects.node_object.templating.get_template_path', mock_template_path, create=True):
with mock.patch('pgsmo.objects.node_object.querying.execute_dict', mock_executor, create=True):
# If: I ask for a collection of nodes
kwargs = {'arg1': 'something'}
nodes = node.get_nodes(mock_conn, 'root', mock_generator, **kwargs)
@ -162,7 +160,7 @@ class TestNodeObject(unittest.TestCase):
mock_render.assert_called_once_with('path', **kwargs)
# ... A query should have been executed
mock_executor.assert_called_once_with(mock_conn, 'SQL')
mock_executor.assert_called_once_with('SQL')
# ... The generator should have been called twice with different object props
mock_generator.assert_any_call(mock_conn, **mock_objs[0])

Просмотреть файл

@ -0,0 +1,65 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
import unittest
from pgsmo.objects.column.column import Column
from pgsmo.utils.querying import ServerConnection
import tests.pgsmo_tests.utils as utils
COLUMN_ROW = {
'name': 'abc',
'datatype': 'character',
'oid': 123,
'has_default_val': True,
'not_null': True
}
class TestColumn(unittest.TestCase):
# CONSTRUCTION TESTS ###################################################
def test_init(self):
props = [
'has_default_value', '_has_default_value',
'not_null', '_not_null'
]
colls = []
name = 'column'
datatype = 'character'
mock_conn = ServerConnection(utils.MockConnection(None))
obj = Column(mock_conn, name, datatype)
utils.validate_init(
Column, name, mock_conn, obj, props, colls,
lambda obj: self._validate_init(obj, datatype)
)
def test_from_node_query(self):
utils.from_node_query_base(Column, COLUMN_ROW, self._validate_column)
def test_get_nodes_for_parent(self):
# Use the test helper for this method
get_nodes_for_parent = (lambda conn: Column.get_nodes_for_parent(conn, 10))
utils.get_nodes_for_parent_base(Column, COLUMN_ROW, get_nodes_for_parent, self._validate_column)
# IMPLEMENTATION DETAILS ###############################################
def _validate_init(self, obj: Column, datatype: str):
self.assertEqual(obj._datatype, datatype)
self.assertEqual(obj.datatype, datatype)
def _validate_column(self, obj: Column, mock_conn: ServerConnection):
# NodeObject basic properties
self.assertIs(obj._conn, mock_conn)
self.assertEqual(obj._oid, COLUMN_ROW['oid'])
self.assertEqual(obj.oid, COLUMN_ROW['oid'])
self.assertEqual(obj._name, COLUMN_ROW['name'])
self.assertEqual(obj.name, COLUMN_ROW['name'])
# Column-specific basic properties
self.assertEqual(obj._datatype, COLUMN_ROW['datatype'])
self.assertEqual(obj.datatype, COLUMN_ROW['datatype'])
self.assertEqual(obj._has_default_value, COLUMN_ROW['has_default_val'])
self.assertEqual(obj.has_default_value, COLUMN_ROW['has_default_val'])
self.assertEqual(obj._not_null, COLUMN_ROW['not_null'])
self.assertEqual(obj.not_null, COLUMN_ROW['not_null'])

Просмотреть файл

@ -0,0 +1,101 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
import unittest
import unittest.mock as mock
from pgsmo.objects.database.database import Database
from pgsmo.utils.querying import ServerConnection
import tests.pgsmo_tests.utils as utils
DATABASE_ROW = {
'name': 'dbname',
'did': 123,
'spcname': 'primary',
'datallowconn': True,
'cancreate': True,
'owner': 10
}
class TestDatabase(unittest.TestCase):
# CONSTRUCTION TESTS ###################################################
def test_init_connected(self):
props = [
'_tablespace', 'tablespace',
'_allow_conn', 'allow_conn',
'_can_create', 'can_create',
'_owner_oid'
]
colls = ['_schemas', 'schemas'] # When connected, these are actually defined
name = 'dbname'
mock_conn = ServerConnection(utils.MockConnection(None, name=name))
db = Database(mock_conn, name)
utils.validate_init(
Database, name, mock_conn, db, props, colls,
lambda obj: self._init_validation(obj, True)
)
def test_init_disconnected(self):
props = [
'_tablespace', 'tablespace',
'_allow_conn', 'allow_conn',
'_can_create', 'can_create',
'_owner_oid',
'_schemas', 'schemas' # When not connected we want these to be set to None
]
colls = []
name = 'dbname'
mock_conn = ServerConnection(utils.MockConnection(None, name='notconnected'))
db = Database(mock_conn, name)
utils.validate_init(
Database, name, mock_conn, db, props, colls,
lambda obj: self._init_validation(obj, False)
)
def test_from_node_query(self):
utils.from_node_query_base(Database, DATABASE_ROW, self._validate_database)
def test_get_nodes_for_parent(self):
# Use the test helper for this method
utils.get_nodes_for_parent_base(Database, DATABASE_ROW, Database.get_nodes_for_parent, self._validate_database)
# METHOD TESTS #########################################################
def test_refresh(self):
# Setup: Create a database object and overwrite the reset method of the child objects
db = Database(ServerConnection(utils.MockConnection(None, name=DATABASE_ROW['name'])), DATABASE_ROW['name'])
db._schemas.reset = mock.MagicMock()
# If: I refresh the database
db.refresh()
# Then: The mocks should have been called
db._schemas.reset.assert_called_once()
# IMPLEMENTATION DETAILS ###############################################
def _init_validation(self, obj: Database, is_connected: bool):
self.assertEqual(obj._is_connected, is_connected)
def _validate_database(self, db: Database, mock_conn: ServerConnection):
# NodeObject basic properties
self.assertIs(db._conn, mock_conn)
self.assertEqual(db._oid, DATABASE_ROW['did'])
self.assertEqual(db.oid, DATABASE_ROW['did'])
self.assertEqual(db._name, DATABASE_ROW['name'])
self.assertEqual(db.name, DATABASE_ROW['name'])
# Database-specific basic properties
self.assertEqual(db._tablespace, DATABASE_ROW['spcname'])
self.assertEqual(db.tablespace, DATABASE_ROW['spcname'])
self.assertEqual(db._allow_conn, DATABASE_ROW['datallowconn'])
self.assertEqual(db.allow_conn, DATABASE_ROW['datallowconn'])
self.assertEqual(db._can_create, DATABASE_ROW['cancreate'])
self.assertEqual(db.can_create, DATABASE_ROW['cancreate'])
self.assertEqual(db._owner_oid, DATABASE_ROW['owner'])
# Child objects
self.assertFalse(db._is_connected)
self.assertIsNone(db._schemas)
self.assertIsNone(db.schemas)

Просмотреть файл

@ -0,0 +1,48 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
import unittest
from pgsmo.objects.role.role import Role
from pgsmo.utils.querying import ServerConnection
import tests.pgsmo_tests.utils as utils
ROLE_ROW = {
'rolname': 'role',
'oid': 123,
'rolcanlogin': True,
'rolsuper': True
}
class TestRole(unittest.TestCase):
# CONSTRUCTION TESTS ###################################################
def test_init(self):
props = ['_can_login', 'can_login', '_super', 'super']
colls = []
utils.init_base(Role, props, colls)
def test_from_node_query(self):
utils.from_node_query_base(Role, ROLE_ROW, self._validate_role)
def test_get_nodes_for_parent(self):
# Use the test helper
utils.get_nodes_for_parent_base(Role, ROLE_ROW, Role.get_nodes_for_parent, self._validate_role)
# IMPLEMENTATION DETAILS ###############################################
def _validate_role(self, role: Role, mock_conn: ServerConnection):
# NodeObject basic properties
self.assertIs(role._conn, mock_conn)
self.assertEqual(role._oid, ROLE_ROW['oid'])
self.assertEqual(role.oid, ROLE_ROW['oid'])
self.assertEqual(role._name, ROLE_ROW['rolname'])
self.assertEqual(role.name, ROLE_ROW['rolname'])
# Role-specific basic properties
self.assertEqual(role._can_login, ROLE_ROW['rolcanlogin'])
self.assertEqual(role.can_login, ROLE_ROW['rolcanlogin'])
self.assertEqual(role._super, ROLE_ROW['rolsuper'])
self.assertEqual(role.super, ROLE_ROW['rolsuper'])

Просмотреть файл

@ -0,0 +1,70 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
import unittest
import unittest.mock as mock
from pgsmo.objects.node_object import NodeCollection
from pgsmo.objects.schema.schema import Schema
from pgsmo.utils.querying import ServerConnection
import tests.pgsmo_tests.utils as utils
SCHEMA_ROW = {
'name': 'schema',
'oid': 123,
'can_create': True,
'has_usage': True
}
class TestSchema(unittest.TestCase):
# CONSTRUCTION TESTS ###################################################
def test_init(self):
props = ['_can_create', 'can_create', '_has_usage', 'has_usage']
collections = ['_tables', 'tables', '_views', 'views']
utils.init_base(Schema, props, collections)
def test_from_node_query(self):
utils.from_node_query_base(Schema, SCHEMA_ROW, self._validate_schema)
def test_from_nodes_for_parent(self):
# Use the test helper for this method
utils.get_nodes_for_parent_base(Schema, SCHEMA_ROW, Schema.get_nodes_for_parent, self._validate_schema)
# METHOD TESTS #########################################################
def test_refresh(self):
# Setup: Create a schema object and mock up the node collection reset methods
mock_conn = ServerConnection(utils.MockConnection(None))
schema = Schema._from_node_query(mock_conn, **SCHEMA_ROW)
schema._tables.reset = mock.MagicMock()
schema._views.reset = mock.MagicMock()
# If: I refresh a schema object
schema.refresh()
# Then: The child object node collections should have been reset
schema._tables.reset.assert_called_once()
schema._views.reset.assert_called_once()
def _validate_schema(self, schema: Schema, mock_conn: ServerConnection):
# NodeObject basic properties
self.assertIs(schema._conn, mock_conn)
self.assertEqual(schema._oid, SCHEMA_ROW['oid'])
self.assertEqual(schema.oid, SCHEMA_ROW['oid'])
self.assertEqual(schema._name, SCHEMA_ROW['name'])
self.assertEqual(schema.name, SCHEMA_ROW['name'])
# Schema-specific basic properties
self.assertEqual(schema._can_create, SCHEMA_ROW['can_create'])
self.assertEqual(schema.can_create, SCHEMA_ROW['can_create'])
self.assertEqual(schema._has_usage, SCHEMA_ROW['has_usage'])
self.assertEqual(schema.has_usage, SCHEMA_ROW['has_usage'])
# Child objects
self.assertIsInstance(schema._tables, NodeCollection)
self.assertIs(schema.tables, schema._tables)
self.assertIsInstance(schema._views, NodeCollection)
self.assertIs(schema.views, schema._views)

Просмотреть файл

@ -0,0 +1,48 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
import unittest
from pgsmo.objects.node_object import NodeCollection
from pgsmo.objects.server.server import Server
from pgsmo.utils.querying import ServerConnection
import tests.pgsmo_tests.utils as utils
class TestServer(unittest.TestCase):
def test_init(self):
# If: I construct a new server object
host = 'host'
port = '1234'
dbname = 'dbname'
mock_conn = utils.MockConnection(None, name=dbname, host=host, port=port)
server = Server(mock_conn)
# Then:
# ... The assigned properties should be assigned
self.assertIsInstance(server._conn, ServerConnection)
self.assertIsInstance(server.connection, ServerConnection)
self.assertIs(server.connection.connection, mock_conn)
self.assertEqual(server._host, host)
self.assertEqual(server.host, host)
self.assertEqual(server._port, int(port))
self.assertEqual(server.port, int(port))
self.assertEqual(server._maintenance_db, dbname)
self.assertEqual(server.maintenance_db, dbname)
self.assertTupleEqual(server.version, server._conn.version)
# ... The optional properties should be assigned to None
self.assertIsNone(server._in_recovery)
self.assertIsNone(server.in_recovery)
self.assertIsNone(server._wal_paused)
self.assertIsNone(server.wal_paused)
# ... The child object collections should be assigned to NodeCollections
self.assertIsInstance(server._databases, NodeCollection)
self.assertIs(server.databases, server._databases)
self.assertIsInstance(server._roles, NodeCollection)
self.assertIs(server.roles, server._roles)
self.assertIsInstance(server._tablespaces, NodeCollection)
self.assertIs(server.tablespaces, server._tablespaces)

Просмотреть файл

@ -0,0 +1,62 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
import unittest
import unittest.mock as mock
from pgsmo.objects.node_object import NodeCollection
from pgsmo.objects.table.table import Table
from pgsmo.utils.querying import ServerConnection
import tests.pgsmo_tests.utils as utils
TABLE_ROW = {
'name': 'tablename',
'oid': 123
}
class TestTable(unittest.TestCase):
# CONSTRUCTION TESTS ###################################################
def test_init(self):
props = []
colls = ['_columns', 'columns']
utils.init_base(Table, props, colls)
def test_from_node_query(self):
utils.from_node_query_base(Table, TABLE_ROW, self._validate_table)
def test_from_nodes_for_parent(self):
utils.get_nodes_for_parent_base(
Table,
TABLE_ROW,
lambda conn: Table.get_nodes_for_parent(conn, 0),
self._validate_table
)
# METHOD TESTS #########################################################
def test_refresh(self):
# Setup: Create a table object and mock up the node collection reset method
mock_conn = ServerConnection(utils.MockConnection(None))
table = Table._from_node_query(mock_conn, **TABLE_ROW)
table._columns.reset = mock.MagicMock()
# If: I refresh a table object
table.refresh()
# Then: The child object node collections should have been reset
table._columns.reset.assert_called_once()
# IMPLEMENTATION DETAILS ###############################################
def _validate_table(self, table: Table, mock_conn: ServerConnection):
# NodeObject basic properties
self.assertIs(table._conn, mock_conn)
self.assertEqual(table._oid, TABLE_ROW['oid'])
self.assertEqual(table.oid, TABLE_ROW['oid'])
self.assertEqual(table._name, TABLE_ROW['name'])
self.assertEqual(table.name, TABLE_ROW['name'])
# Child objects
self.assertIsInstance(table._columns, NodeCollection)
self.assertIs(table.columns, table._columns)

Просмотреть файл

@ -0,0 +1,47 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
import unittest
from pgsmo.objects.tablespace.tablespace import Tablespace
from pgsmo.utils.querying import ServerConnection
import tests.pgsmo_tests.utils as utils
NODE_ROW = {
'name': 'test',
'oid': 123,
'owner': 10
}
class TestTablespace(unittest.TestCase):
# CONSTRUCTION TESTS ###################################################
def test_init(self):
props = ['owner', '_owner']
utils.init_base(Tablespace, props, [])
def test_from_node_query(self):
utils.from_node_query_base(Tablespace, NODE_ROW, self._validate_tablespace)
def test_get_nodes_for_parent(self):
utils.get_nodes_for_parent_base(
Tablespace,
NODE_ROW,
Tablespace.get_nodes_for_parent,
self._validate_tablespace
)
# IMPLEMENTATION DETAILS ###############################################
def _validate_tablespace(self, obj: Tablespace, mock_conn: ServerConnection):
# NodeObject basic properties
self.assertIs(obj._conn, mock_conn)
self.assertEqual(obj._oid, NODE_ROW['oid'])
self.assertEqual(obj.oid, NODE_ROW['oid'])
self.assertEqual(obj._name, NODE_ROW['name'])
self.assertEqual(obj.name, NODE_ROW['name'])
# Tablespace-specific properties
self.assertEqual(obj._owner, NODE_ROW['owner'])
self.assertEqual(obj.owner, NODE_ROW['owner'])

Просмотреть файл

@ -0,0 +1,62 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
import unittest
import unittest.mock as mock
from pgsmo.objects.node_object import NodeCollection
from pgsmo.objects.view.view import View
from pgsmo.utils.querying import ServerConnection
import tests.pgsmo_tests.utils as utils
NODE_ROW = {
'name': 'viewname',
'oid': 123
}
class TestTable(unittest.TestCase):
# CONSTRUCTION TESTS ###################################################
def test_init(self):
props = []
colls = ['_columns', 'columns']
utils.init_base(View, props, colls)
def test_from_node_query(self):
utils.from_node_query_base(View, NODE_ROW, self._validate_view)
def test_from_nodes_for_parent(self):
utils.get_nodes_for_parent_base(
View,
NODE_ROW,
lambda conn: View.get_nodes_for_parent(conn, 0),
self._validate_view
)
# METHOD TESTS #########################################################
def test_refresh(self):
# Setup: Create a table object and mock up the node collection reset method
mock_conn = ServerConnection(utils.MockConnection(None))
table = View._from_node_query(mock_conn, **NODE_ROW)
table._columns.reset = mock.MagicMock()
# If: I refresh a table object
table.refresh()
# Then: The child object node collections should have been reset
table._columns.reset.assert_called_once()
# IMPLEMENTATION DETAILS ###############################################
def _validate_view(self, view: View, mock_conn: ServerConnection):
# NodeObject basic properties
self.assertIs(view._conn, mock_conn)
self.assertEqual(view._oid, NODE_ROW['oid'])
self.assertEqual(view.oid, NODE_ROW['oid'])
self.assertEqual(view._name, NODE_ROW['name'])
self.assertEqual(view.name, NODE_ROW['name'])
# Child objects
self.assertIsInstance(view._columns, NodeCollection)
self.assertIs(view.columns, view._columns)

Просмотреть файл

@ -0,0 +1,73 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
import unittest
import pgsmo.utils as pgsmo_utils
import tests.pgsmo_tests.utils as utils
class TestServerConnection(unittest.TestCase):
def test_server_conn_init(self):
# Setup: Create a mock connection with an 'interesting' version
mock_conn = utils.MockConnection(None, version='100216')
# If: I initialize a server connection
# noinspection PyTypeChecker
server_conn = pgsmo_utils.querying.ServerConnection(mock_conn)
# Then: The properties should be properly set
self.assertEqual(server_conn._conn, mock_conn)
self.assertEqual(server_conn.connection, mock_conn)
expected_dict = {'dbname': 'postgres', 'host': 'localhost', 'port': '25565'}
self.assertDictEqual(server_conn._dsn_parameters, expected_dict)
self.assertDictEqual(server_conn.dsn_parameters, expected_dict)
self.assertEqual('pg', server_conn.server_type)
self.assertTupleEqual((10, 2, 16), server_conn.version)
def test_execute_dict_success(self):
# Setup: Create a mock server connection that will return a result set
mock_cursor = utils.MockCursor(utils.get_mock_results())
mock_conn = utils.MockConnection(mock_cursor)
# noinspection PyTypeChecker
server_conn = pgsmo_utils.querying.ServerConnection(mock_conn)
# If: I execute a query as a dictionary
results = server_conn.execute_dict('SELECT * FROM pg_class')
# Then:
# ... Both the columns and the results should be returned
self.assertIsInstance(results, tuple)
self.assertEqual(len(results), 2)
# ... I should get a list of columns returned to me
cols = results[0]
self.assertIsInstance(cols, list)
self.assertListEqual(cols, mock_cursor.description)
# ... I should get the results formatted as a list of dictionaries
rows = results[1]
self.assertIsInstance(rows, list)
for idx, row in enumerate(rows):
self.assertDictEqual(row, mock_cursor._results[1][idx])
# ... The cursor should be closed
mock_cursor.close.assert_called_once()
def test_execute_dict_fail(self):
# Setup: Create a mock server connection that will raise an exception
mock_cursor = utils.MockCursor(None, throw_on_execute=True)
mock_conn = utils.MockConnection(mock_cursor)
# noinspection PyTypeChecker
server_conn = pgsmo_utils.querying.ServerConnection(mock_conn)
# If: I execute a query as a dictionary
# Then:
# ... I should get an exception
with self.assertRaises(Exception):
server_conn.execute_dict('SELECT * FROM pg_class')
# ... The cursor should be closed
mock_cursor.close.assert_called_once()

Просмотреть файл

@ -0,0 +1,178 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
import os.path as path
import unittest
import unittest.mock as mock
import jinja2
import pgsmo.utils as pgsmo_utils
class TestTemplatingUtils(unittest.TestCase):
# GET_TEMPLATE_ROOT TESTS ##############################################
def test_get_template_root(self):
# If: I attempt to get the template root of this file
root = pgsmo_utils.templating.get_template_root(__file__, 'templates')
# Then: The output should match what I expected
expected = path.join(path.dirname(__file__), 'templates')
self.assertEqual(root, expected)
# GET_TEMPLATE_PATH TESTS ##############################################
def test_get_template_path_no_match(self):
# Setup: Create a mock os walker
with mock.patch('pgsmo.utils.templating.os.walk', _os_walker, create=True):
with self.assertRaises(ValueError):
# If: I attempt to get the path of a template when it does not exist
# Then: An exception should be thrown
pgsmo_utils.templating.get_template_path(TEMPLATE_ROOT_NAME, 'doesnotexist.sql', (9, 0, 0))
def test_get_template_path_default(self):
# Setup: Create a mock os walker
with mock.patch('pgsmo.utils.templating.os.walk', _os_walker, create=True):
# If: I attempt to get a template path when there is not a good version match
template_path = pgsmo_utils.templating.get_template_path(TEMPLATE_ROOT_NAME, 'template.sql', (8, 1, 0))
# Then: The template path should point to the default folder
self.assertEqual(template_path, path.join(TEMPLATE_ROOT_NAME, '+default', 'template.sql'))
def test_get_template_path_exact_version_match(self):
# Setup: Create a mock os walker
with mock.patch('pgsmo.utils.templating.os.walk', _os_walker, create=True):
# If: I attempt to get the path of a template when there's a exact version match
template_path = pgsmo_utils.templating.get_template_path(TEMPLATE_ROOT_NAME, 'template.sql', (9, 0, 0))
# Then: The returned path must match the exact version
self.assertEqual(template_path, path.join(TEMPLATE_ROOT_NAME, '9.0', 'template.sql'))
def test_get_template_path_plus_match(self):
# Setup: Create a mock os walker
with mock.patch('pgsmo.utils.templating.os.walk', _os_walker, create=True):
# If: I attempt to get a template path when there is a version range that matches
template_path = pgsmo_utils.templating.get_template_path(TEMPLATE_ROOT_NAME, 'template.sql', (9, 3, 0))
# Then: The returned path should match the next lowest version range
self.assertEqual(template_path, path.join(TEMPLATE_ROOT_NAME, '9.2_plus', 'template.sql'))
def test_get_template_path_invalid_folder(self):
# Setup: Create a mock os walker that has an invalid folder in it
with mock.patch('pgsmo.utils.templating.os.walk', _bad_os_walker, create=True):
with self.assertRaises(ValueError):
# If: I attempt to get a template path when there is an invalid folder in the template folder
# Then: I should get an exception
pgsmo_utils.templating.get_template_path(TEMPLATE_ROOT_NAME, 'template.sql', (9, 0, 0))
# RENDER_TEMPLATE TESTS ################################################
def test_render_template(self):
# NOTE: This test has an external dependency on dummy_template.txt
# If: I render a string
template_file = 'dummy_template.txt'
template_folder = path.dirname(__file__)
template_path = path.normpath(path.join(template_folder, template_file))
rendered = pgsmo_utils.templating.render_template(template_path, foo='bar')
# Then:
# ... The output should be properly rendered
self.assertEqual(rendered, 'bar')
# ... The environment should be cached
self.assertIsInstance(pgsmo_utils.templating.TEMPLATE_ENVIRONMENTS, dict)
env = pgsmo_utils.templating.TEMPLATE_ENVIRONMENTS.get(template_folder)
self.assertIsInstance(env, jinja2.Environment)
# ... The environment should have the proper filters defined
self.assertEquals(env.filters['qtLiteral'], pgsmo_utils.templating.qt_literal)
self.assertEquals(env.filters['qtIdent'], pgsmo_utils.templating.qt_ident)
self.assertEquals(env.filters['qtTypeIdent'], pgsmo_utils.templating.qt_type_ident)
def test_render_template_cached(self):
# NOTE: This test has an external dependency on dummy_template.txt
# If:
# ... I render a string
template_file = 'dummy_template.txt'
template_folder = path.dirname(__file__)
template_path = path.normpath(path.join(template_folder, template_file))
rendered1 = pgsmo_utils.templating.render_template(template_path, foo='bar')
env1 = pgsmo_utils.templating.TEMPLATE_ENVIRONMENTS.get(template_folder)
# ... I render the same string
rendered2 = pgsmo_utils.templating.render_template(template_path, foo='bar')
env2 = pgsmo_utils.templating.TEMPLATE_ENVIRONMENTS.get(template_folder)
# Then: The environments used should be literally the same
self.assertEqual(rendered1, rendered2)
self.assertIs(env1, env2)
# RENDER_TEMPLATE_STRING TESTS #########################################
def test_render_template_string(self):
# NOTE: doing very minimal test here since this function just uses jinja2 functionality
# If: I render a string
rendered = pgsmo_utils.templating.render_template_string('{{foo}}', foo='bar')
# Then: The string should be properly rendered
self.assertEqual(rendered, 'bar')
# Mock setup for a template tree
TEMPLATE_ROOT_NAME = 'templates'
TEMPLATE_EXACT_1 = (path.join(TEMPLATE_ROOT_NAME, '9.0'), [], ['template.sql'])
TEMPLATE_EXACT_2 = (path.join(TEMPLATE_ROOT_NAME, '9.1'), [], ['template.sql'])
TEMPLATE_PLUS_1 = (path.join(TEMPLATE_ROOT_NAME, '9.2_plus'), [], ['template.sql'])
TEMPLATE_PLUS_2 = (path.join(TEMPLATE_ROOT_NAME, '9.4_plus'), [], ['template.sql'])
TEMPLATE_DEFAULT = (path.join(TEMPLATE_ROOT_NAME, '+default'), [], ['template.sql'])
# Tests that skipped folders are not returned
TEMPLATE_SKIP = (path.join(TEMPLATE_ROOT_NAME, 'macros'), [], ['template.sql'])
# Root tree
TEMPLATE_ROOT = (
TEMPLATE_ROOT_NAME,
[
TEMPLATE_DEFAULT[0],
TEMPLATE_EXACT_1[0],
TEMPLATE_EXACT_2[0],
TEMPLATE_PLUS_1[0],
TEMPLATE_PLUS_2[0],
TEMPLATE_SKIP[0]
],
[]
)
def _os_walker(x):
if x == TEMPLATE_ROOT_NAME:
yield TEMPLATE_ROOT
yield TEMPLATE_DEFAULT
yield TEMPLATE_EXACT_1
yield TEMPLATE_EXACT_2
yield TEMPLATE_PLUS_1
yield TEMPLATE_PLUS_2
yield TEMPLATE_SKIP
if x == TEMPLATE_DEFAULT[0]:
yield TEMPLATE_DEFAULT
if x == TEMPLATE_EXACT_1[0]:
yield TEMPLATE_EXACT_1
if x == TEMPLATE_EXACT_2[0]:
yield TEMPLATE_EXACT_2
if x == TEMPLATE_PLUS_1[0]:
yield TEMPLATE_PLUS_1
if x == TEMPLATE_PLUS_2[0]:
yield TEMPLATE_PLUS_2
if x == TEMPLATE_SKIP[0]:
yield TEMPLATE_SKIP
TEMPLATE_BAD = (path.join(TEMPLATE_ROOT_NAME, 'bad_folder'), [], ['template.sql'])
TEMPLATE_BAD_ROOT = (TEMPLATE_ROOT_NAME, [TEMPLATE_BAD[0]], [])
def _bad_os_walker(x):
if x == TEMPLATE_ROOT_NAME:
yield TEMPLATE_BAD_ROOT
yield TEMPLATE_BAD
if x == TEMPLATE_BAD[0]:
yield TEMPLATE_BAD

Просмотреть файл

@ -0,0 +1,84 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
import unittest
import pgsmo.utils as pgsmo_utils
class TestTemplatingFilters(unittest.TestCase):
# SCAN KEYWORD EXTRA LOOKUP TESTS ######################################
def test_scan_keyword_extra_lookup_extra(self):
# If: I scan for a keyword that exists in the extra keywords
output = pgsmo_utils.templating.scan_keyword_extra_lookup('connect')
# Then: I should get back the extra keyword info
self.assertEqual(output, pgsmo_utils.templating._EXTRA_KEYWORDS['connect'])
def test_scan_keyword_extra_lookup_standard(self):
# If: I scan for a keyword that exists in the standard keywords
output = pgsmo_utils.templating.scan_keyword_extra_lookup('abort')
# Then: I should get back the standard keyword value
self.assertEqual(output, pgsmo_utils.templating._KEYWORD_DICT['abort'])
def test_scan_keyword_extra_lookup_none(self):
# If: I scan for a keyword that doesn't exist
output = pgsmo_utils.templating.scan_keyword_extra_lookup('does_not_exist')
# Then: I should get back None
self.assertIsNone(output)
# QTLITERAL TESTS ######################################################
def test_qtliteral_no_encoding(self):
# If: I provide a value that doesn't have an encoding
output = pgsmo_utils.templating.qt_literal(4)
# Then: I should get a quoted literal back
self.assertEqual(output, '4')
def test_qtliteral_bytes(self):
# If: I provide a byte array
output = pgsmo_utils.templating.qt_literal(b'123')
# Then: I should get the byte array decoded as UTF-8 back
self.assertEqual(output, "'123'::bytea")
def test_qtliteral_string(self):
# If: I provide a string
output = pgsmo_utils.templating.qt_literal("123")
# Then: I should get the quoted string back
self.assertEqual(output, "'123'")
# NEEDS QUOTING TESTS ##################################################
# TODO: Add tests that are more based on scenarios and less on code paths
def test_needs_quoting_int(self):
# If: An int is provided
# Then: It should always be quoted
self.assertTrue(pgsmo_utils.templating.needs_quoting(4, False))
self.assertTrue(pgsmo_utils.templating.needs_quoting(4, True))
def test_needs_quoting_type_legal_spaces(self):
# If: A type is provided that legally has a space in it
# Then: It shouldn't be quoted
self.assertFalse(pgsmo_utils.templating.needs_quoting('time with time zone', True))
self.assertFalse(pgsmo_utils.templating.needs_quoting('time with time zone[]', True))
def test_needs_quoting_type_already_quoted(self):
# If: A type is provided that is already quoted
# Then: It shouldn't be quoted
self.assertFalse(pgsmo_utils.templating.needs_quoting('"int"', True))
def test_needs_quoting_numeric_value(self):
# If: A value is numeric (? starts with 0-9)
# Then: It should be quoted
self.assertTrue(pgsmo_utils.templating.needs_quoting('2000', False))
def test_needs_quoting_non_alphanumeric(self):
# If: A value is not lowercase alphanumeric
# Then: It should be quoted
self.assertTrue(pgsmo_utils.templating.needs_quoting('Something+Else', False))
self.assertTrue(pgsmo_utils.templating.needs_quoting('Something+Else', True))

179
tests/pgsmo_tests/utils.py Normal file
Просмотреть файл

@ -0,0 +1,179 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
from typing import Callable, List, Optional, Tuple
import unittest
import unittest.mock as mock
from psycopg2 import DatabaseError
from psycopg2.extensions import Column, connection
from pgsmo.objects.node_object import NodeObject, NodeCollection
from pgsmo.utils.querying import ServerConnection
# MOCK CONNECTION ##########################################################
def get_mock_columns(col_count: int) -> List[Column]:
return [Column(f'column{i}', None, 10, 10, None, None, True) for i in range(0, col_count+1)]
def get_named_mock_columns(col_names: List[str]) -> List[Column]:
return [Column(x, None, 10, 10, None, None, True) for x in col_names]
def get_mock_results(col_count: int=5, row_count: int=5) -> Tuple[List[Column], List[dict]]:
rows = []
cols = get_mock_columns(col_count)
for i in range(0, len(cols)):
# Add the column to the rows
for j in range(0, row_count+1):
if len(rows) >= j:
rows.append({})
rows[j][cols[i].name] = f'value{j}.{i}'
return cols, rows
class MockCursor:
def __init__(self, results: Optional[Tuple[List[Column], List[dict]]], throw_on_execute=False):
# Setup the results, that will change value once the cursor is executed
self.description = None
self.rowcount = None
self._results = results
self._throw_on_execute = throw_on_execute
# Define iterator state
self._has_been_read = False
self._iter_index = 0
# Define mocks for the methods of the cursor
self.execute = mock.MagicMock(side_effect=self._execute)
self.close = mock.MagicMock()
def __iter__(self):
# If we haven't read yet, raise an error
# Or if we have read but we're past the end of the list, raise an error
if not self._has_been_read or self._iter_index > len(self._results[1]):
raise StopIteration
# From python 3.6+ this dicts preserve order, so this isn't an issue
yield list(self._results[1][self._iter_index].values())
self._iter_index += 1
def _execute(self, query, params):
# Raise error if that was expected, otherwise set the output
if self._throw_on_execute:
raise DatabaseError()
self.description = self._results[0]
self.rowcount = len(self._results[1])
self._has_been_read = True
class MockConnection(connection):
def __init__(
self,
cur: Optional[MockCursor],
version: str='90602',
name: str='postgres',
host: str='localhost',
port: str='25565'):
# Setup the properties
self._server_version = version
# Setup mocks for the connection
self.close = mock.MagicMock()
self.cursor = mock.MagicMock(return_value=cur)
dsn_params = {'dbname': name, 'host': host, 'port': port}
self.get_dsn_parameters = mock.MagicMock(return_value=dsn_params)
@property
def server_version(self):
return self._server_version
# OBJECT TEST HELPERS ######################################################
def get_nodes_for_parent_base(class_, data: dict, get_nodes_for_parent: Callable, validate_obj: Callable):
# Setup: Create a mockup server connection
mock_cur = MockCursor((get_named_mock_columns(list(data.keys())), [data for i in range(0, 6)]))
mock_conn = ServerConnection(MockConnection(mock_cur))
# ... Create a mock template renderer
mock_render = mock.MagicMock(return_value="SQL")
mock_template_path = mock.MagicMock(return_value="path")
# ... Create a testcase for calling asserts with
test_case = unittest.TestCase('__init__')
# ... Patch the templating
with mock.patch(class_.__module__ + '.templating.render_template', mock_render, create=True):
with mock.patch(class_.__module__ + '.templating.get_template_path', mock_template_path, create=True):
# If: ask for a collection of nodes
output = get_nodes_for_parent(mock_conn)
# Then:
# ... The output should be a list of objects
test_case.assertIsInstance(output, list)
for obj in output:
# ... The object must be the class that was passed in
test_case.assertIsInstance(obj, class_)
# ... Call the validator on the object
validate_obj(obj, mock_conn)
def from_node_query_base(class_, data: dict, validate_obj: Callable):
# If: I create a new object from a node row
mock_conn = ServerConnection(MockConnection(None))
obj = class_._from_node_query(mock_conn, **data)
# Then:
# ... The returned object must be an instance of the class
test_case = unittest.TestCase('__init__')
test_case.assertIsInstance(obj, NodeObject)
test_case.assertIsInstance(obj, class_)
# ... Call the validation function
validate_obj(obj, mock_conn)
def init_base(class_, props: List[str], collections: List[str], custom_validation: Callable=None):
# If: I create an instance of the provided class
mock_conn = ServerConnection(MockConnection(None))
name = 'test'
obj = class_(mock_conn, name)
validate_init(class_, name, mock_conn, obj, props, collections, custom_validation)
def validate_init(class_, name, mock_conn, obj, props: List[str], collections: List[str],
custom_validation: Callable=None):
# Then:
# ... The object must be of the type that was provided
test_case = unittest.TestCase('__init__')
test_case.assertIsInstance(obj, NodeObject)
test_case.assertIsInstance(obj, class_)
# ... The NodeObject basic properties should be set up appropriately
test_case.assertIs(obj._conn, mock_conn)
test_case.assertEqual(obj._name, name)
test_case.assertEqual(obj.name, name)
test_case.assertIsNone(obj._oid)
test_case.assertIsNone(obj.oid)
# ... The rest of the properties should be none
for prop in props:
test_case.assertIsNone(getattr(obj, prop))
# ... The child properties should be assigned to node collections
for coll in collections:
test_case.assertIsInstance(getattr(obj, coll), NodeCollection)
# ... Run the custom validation
if custom_validation is not None:
custom_validation(obj)