Adding initial unit tests for PGSMO (#49)
* Adding unit tests for querying helper Renaming ConnectionWrapper to ServerConnection (a la SMO) Making execute methods be methods on the ServerConnection * Adding tests for templating, except pgAdmin code * Adding tests for ScanKeywordExtraLookup * Some more tests... * Fixing more merge issues * Fixing nonsense with conflicting package names * Adding unittests for column class * Adding unittests for database objects * Adding unittests for role objects * Adding unittests for schema objects * Refactoring out init test which will be reused * Refactoring out init tests to make it easier to reuse the validation code * Adding tests for table class * Adding tests for tablespace class * Adding tests for view class * Adding tests for server class * Flake8 stuff and things * Fixing bug from merge... * Changes as per PR comments
This commit is contained in:
Родитель
c63502f826
Коммит
3cee508796
|
@ -6,18 +6,19 @@
|
|||
from typing import List, Optional
|
||||
|
||||
import pgsmo.objects.node_object as node
|
||||
import pgsmo.utils as utils
|
||||
import pgsmo.utils.querying as querying
|
||||
import pgsmo.utils.templating as templating
|
||||
|
||||
TEMPLATE_ROOT = utils.templating.get_template_root(__file__, 'templates')
|
||||
TEMPLATE_ROOT = templating.get_template_root(__file__, 'templates')
|
||||
|
||||
|
||||
class Column(node.NodeObject):
|
||||
@classmethod
|
||||
def get_nodes_for_parent(cls, conn: utils.querying.ConnectionWrapper, tid: int) -> List['Column']:
|
||||
def get_nodes_for_parent(cls, conn: querying.ServerConnection, tid: int) -> List['Column']:
|
||||
return node.get_nodes(conn, TEMPLATE_ROOT, cls._from_node_query, tid=tid)
|
||||
|
||||
@classmethod
|
||||
def _from_node_query(cls, conn: utils.querying.ConnectionWrapper, **kwargs) -> 'Column':
|
||||
def _from_node_query(cls, conn: querying.ServerConnection, **kwargs) -> 'Column':
|
||||
"""
|
||||
Creates a new Column object based on the the results from the column nodes query
|
||||
:param conn: Connection used to execute the column nodes query
|
||||
|
@ -37,7 +38,7 @@ class Column(node.NodeObject):
|
|||
|
||||
return col
|
||||
|
||||
def __init__(self, conn: utils.querying.ConnectionWrapper, name: str, datatype: str):
|
||||
def __init__(self, conn: querying.ServerConnection, name: str, datatype: str):
|
||||
"""
|
||||
Initializes a new instance of a Column
|
||||
:param conn: Connection to the server/database that this object will belong to
|
||||
|
@ -63,11 +64,3 @@ class Column(node.NodeObject):
|
|||
@property
|
||||
def not_null(self) -> Optional[bool]:
|
||||
return self._not_null
|
||||
|
||||
# METHODS ##############################################################
|
||||
def refresh(self):
|
||||
self._fetch_properties()
|
||||
|
||||
# IMPLEMENTATION DETAILS ###############################################
|
||||
def _fetch_properties(self):
|
||||
pass
|
||||
|
|
|
@ -7,18 +7,19 @@ from typing import List, Optional # noqa
|
|||
|
||||
import pgsmo.objects.node_object as node
|
||||
from pgsmo.objects.schema.schema import Schema
|
||||
import pgsmo.utils as utils
|
||||
import pgsmo.utils.querying as querying
|
||||
import pgsmo.utils.templating as templating
|
||||
|
||||
TEMPLATE_ROOT = utils.templating.get_template_root(__file__, 'templates')
|
||||
TEMPLATE_ROOT = templating.get_template_root(__file__, 'templates')
|
||||
|
||||
|
||||
class Database(node.NodeObject):
|
||||
@classmethod
|
||||
def get_nodes_for_parent(cls, conn: utils.querying.ConnectionWrapper) -> List['Database']:
|
||||
def get_nodes_for_parent(cls, conn: querying.ServerConnection) -> List['Database']:
|
||||
return node.get_nodes(conn, TEMPLATE_ROOT, cls._from_node_query, last_system_oid=0)
|
||||
|
||||
@classmethod
|
||||
def _from_node_query(cls, conn: utils.querying.ConnectionWrapper, **kwargs) -> 'Database':
|
||||
def _from_node_query(cls, conn: querying.ServerConnection, **kwargs) -> 'Database':
|
||||
"""
|
||||
Creates a new Database object based on the results from a query to lookup databases
|
||||
:param conn: Connection used to generate the db info query
|
||||
|
@ -34,7 +35,6 @@ class Database(node.NodeObject):
|
|||
"""
|
||||
db = cls(conn, kwargs['name'])
|
||||
db._oid = kwargs['did']
|
||||
db._is_connected = kwargs['name'] == conn.dsn_parameters.get('dbname')
|
||||
db._tablespace = kwargs['spcname']
|
||||
db._allow_conn = kwargs['datallowconn']
|
||||
db._can_create = kwargs['cancreate']
|
||||
|
@ -42,13 +42,13 @@ class Database(node.NodeObject):
|
|||
|
||||
return db
|
||||
|
||||
def __init__(self, conn: utils.querying.ConnectionWrapper, name: str):
|
||||
def __init__(self, conn: querying.ServerConnection, name: str):
|
||||
"""
|
||||
Initializes a new instance of a database
|
||||
:param name: Name of the database
|
||||
"""
|
||||
super(Database, self).__init__(conn, name)
|
||||
self._is_connected: bool = False
|
||||
self._is_connected: bool = conn.dsn_parameters.get('dbname') == name
|
||||
|
||||
# Declare the optional parameters
|
||||
self._tablespace: Optional[str] = None
|
||||
|
@ -57,7 +57,9 @@ class Database(node.NodeObject):
|
|||
self._owner_oid: Optional[int] = None
|
||||
|
||||
# Declare the child items
|
||||
self._schemas: node.NodeCollection = node.NodeCollection(lambda: Schema.get_nodes_for_parent(self._conn))
|
||||
self._schemas: Optional[node.NodeCollection] = None if not self._is_connected else node.NodeCollection(
|
||||
lambda: Schema.get_nodes_for_parent(conn)
|
||||
)
|
||||
|
||||
# PROPERTIES ###########################################################
|
||||
# TODO: Create setters for optional values
|
||||
|
@ -80,11 +82,6 @@ class Database(node.NodeObject):
|
|||
return self._schemas
|
||||
|
||||
# METHODS ##############################################################
|
||||
|
||||
def refresh(self):
|
||||
self._fetch_properties()
|
||||
self._schemas.reset()
|
||||
|
||||
def create(self):
|
||||
pass
|
||||
|
||||
|
@ -94,6 +91,7 @@ class Database(node.NodeObject):
|
|||
def delete(self):
|
||||
pass
|
||||
|
||||
# IMPLEMENTATION DETAILS ###############################################
|
||||
def _fetch_properties(self):
|
||||
pass
|
||||
def refresh(self):
|
||||
"""Resets the internal collection of child objects"""
|
||||
if self._schemas is not None:
|
||||
self._schemas.reset()
|
||||
|
|
|
@ -14,12 +14,12 @@ import pgsmo.utils.querying as querying
|
|||
class NodeObject:
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def _from_node_query(cls, conn: querying.ConnectionWrapper, **kwargs):
|
||||
def _from_node_query(cls, conn: querying.ServerConnection, **kwargs):
|
||||
pass
|
||||
|
||||
def __init__(self, conn: querying.ConnectionWrapper, name: str):
|
||||
def __init__(self, conn: querying.ServerConnection, name: str):
|
||||
# Define the state of the object
|
||||
self._conn: querying.ConnectionWrapper = conn
|
||||
self._conn: querying.ServerConnection = conn
|
||||
|
||||
# Declare node basic properties
|
||||
self._name: str = name
|
||||
|
@ -92,9 +92,9 @@ class NodeCollection:
|
|||
T = TypeVar('T')
|
||||
|
||||
|
||||
def get_nodes(conn: querying.ConnectionWrapper,
|
||||
def get_nodes(conn: querying.ServerConnection,
|
||||
template_root: str,
|
||||
generator: Callable[[type, querying.ConnectionWrapper, Dict[str, any]], T],
|
||||
generator: Callable[[type, querying.ServerConnection, Dict[str, any]], T],
|
||||
**kwargs) -> List[T]:
|
||||
"""
|
||||
Renders and executes nodes.sql for the given database version to generate a list of NodeObjects
|
||||
|
@ -108,6 +108,6 @@ def get_nodes(conn: querying.ConnectionWrapper,
|
|||
templating.get_template_path(template_root, 'nodes.sql', conn.version),
|
||||
**kwargs
|
||||
)
|
||||
cols, rows = querying.execute_dict(conn, sql)
|
||||
cols, rows = conn.execute_dict(sql)
|
||||
|
||||
return [generator(conn, **row) for row in rows]
|
||||
|
|
|
@ -6,15 +6,16 @@
|
|||
from typing import List, Optional
|
||||
|
||||
from pgsmo.objects.node_object import NodeObject, get_nodes
|
||||
import pgsmo.utils as utils
|
||||
import pgsmo.utils.querying as querying
|
||||
import pgsmo.utils.templating as templating
|
||||
|
||||
|
||||
TEMPLATE_ROOT = utils.templating.get_template_root(__file__, 'templates')
|
||||
TEMPLATE_ROOT = templating.get_template_root(__file__, 'templates')
|
||||
|
||||
|
||||
class Role(NodeObject):
|
||||
@classmethod
|
||||
def get_nodes_for_parent(cls, conn: utils.querying.ConnectionWrapper) -> List['Role']:
|
||||
def get_nodes_for_parent(cls, conn: querying.ServerConnection) -> List['Role']:
|
||||
"""
|
||||
Generates a list of roles for a given server. Intended to only be called by a Server object
|
||||
:param conn: Connection to use to look up the roles for the server
|
||||
|
@ -23,7 +24,7 @@ class Role(NodeObject):
|
|||
return get_nodes(conn, TEMPLATE_ROOT, cls._from_node_query)
|
||||
|
||||
@classmethod
|
||||
def _from_node_query(cls, conn: utils.querying.ConnectionWrapper, **kwargs) -> 'Role':
|
||||
def _from_node_query(cls, conn: querying.ServerConnection, **kwargs) -> 'Role':
|
||||
"""
|
||||
Creates a Role object from the result of a role node query
|
||||
:param conn: Connection that executed the role node query
|
||||
|
@ -39,7 +40,7 @@ class Role(NodeObject):
|
|||
|
||||
return role
|
||||
|
||||
def __init__(self, conn: utils.querying.ConnectionWrapper, name: str):
|
||||
def __init__(self, conn: querying.ServerConnection, name: str):
|
||||
"""
|
||||
Initializes internal state of a Role object
|
||||
:param conn: Connection that executed the role node query
|
||||
|
|
|
@ -9,20 +9,21 @@ from typing import List, Optional
|
|||
import pgsmo.objects.node_object as node
|
||||
from pgsmo.objects.table.table import Table
|
||||
from pgsmo.objects.view.view import View
|
||||
import pgsmo.utils as utils
|
||||
import pgsmo.utils.querying as querying
|
||||
import pgsmo.utils.templating as templating
|
||||
|
||||
|
||||
TEMPLATE_ROOT = utils.templating.get_template_root(__file__, 'templates')
|
||||
TEMPLATE_ROOT = templating.get_template_root(__file__, 'templates')
|
||||
|
||||
|
||||
class Schema(node.NodeObject):
|
||||
@classmethod
|
||||
def get_nodes_for_parent(cls, conn: utils.querying.ConnectionWrapper) -> List['Schema']:
|
||||
def get_nodes_for_parent(cls, conn: querying.ServerConnection) -> List['Schema']:
|
||||
type_template_root = path.join(TEMPLATE_ROOT, conn.server_type)
|
||||
return node.get_nodes(conn, type_template_root, cls._from_node_query)
|
||||
|
||||
@classmethod
|
||||
def _from_node_query(cls, conn: utils.querying.ConnectionWrapper, **kwargs) -> 'Schema':
|
||||
def _from_node_query(cls, conn: querying.ServerConnection, **kwargs) -> 'Schema':
|
||||
"""
|
||||
Creates an instance of a schema object from the results of a nodes query
|
||||
:param conn: The connection used to execute the nodes query
|
||||
|
@ -41,7 +42,7 @@ class Schema(node.NodeObject):
|
|||
|
||||
return schema
|
||||
|
||||
def __init__(self, conn: utils.querying.ConnectionWrapper, name: str):
|
||||
def __init__(self, conn: querying.ServerConnection, name: str):
|
||||
super(Schema, self).__init__(conn, name)
|
||||
|
||||
# Declare the optional parameters
|
||||
|
@ -53,7 +54,7 @@ class Schema(node.NodeObject):
|
|||
lambda: Table.get_nodes_for_parent(self._conn, self._oid)
|
||||
)
|
||||
self._views: node.NodeCollection = node.NodeCollection(
|
||||
lambda: View.get_nodes_for_parent(self._conn, self.oid)
|
||||
lambda: View.get_nodes_for_parent(self._conn, self._oid)
|
||||
)
|
||||
|
||||
# PROPERTIES ###########################################################
|
||||
|
@ -76,5 +77,6 @@ class Schema(node.NodeObject):
|
|||
|
||||
# METHODS ##############################################################
|
||||
def refresh(self) -> None:
|
||||
self._tables = Table.get_tables_for_schema(self._conn, self._oid)
|
||||
self._views = View.get_views_for_schema(self._conn, self._oid)
|
||||
"""Resets the internal collections of child objects"""
|
||||
self._tables.reset()
|
||||
self._views.reset()
|
||||
|
|
|
@ -25,10 +25,10 @@ class Server:
|
|||
:param conn: psycopg2 connection
|
||||
"""
|
||||
# Everything we know about the server will be based on the connection
|
||||
self._conn = utils.querying.ConnectionWrapper(conn)
|
||||
self._conn = utils.querying.ServerConnection(conn)
|
||||
|
||||
# Declare the server properties
|
||||
props = self._conn.connection.get_dsn_parameters()
|
||||
props = self._conn.dsn_parameters
|
||||
self._host: str = props['host']
|
||||
self._port: int = int(props['port'])
|
||||
self._maintenance_db: str = props['dbname']
|
||||
|
@ -55,8 +55,8 @@ class Server:
|
|||
return self._host
|
||||
|
||||
@property
|
||||
def in_recovery(self) -> bool:
|
||||
"""Whether or not the server is in recovery mode"""
|
||||
def in_recovery(self) -> Optional[bool]:
|
||||
"""Whether or not the server is in recovery mode. If None, value was not loaded from server"""
|
||||
return self._in_recovery
|
||||
|
||||
@property
|
||||
|
@ -75,8 +75,8 @@ class Server:
|
|||
return self._conn.version
|
||||
|
||||
@property
|
||||
def wal_paused(self) -> bool:
|
||||
"""Whether or not the Write-Ahead Log (WAL) is paused"""
|
||||
def wal_paused(self) -> Optional[bool]:
|
||||
"""Whether or not the Write-Ahead Log (WAL) is paused. If None, value was not loaded from server"""
|
||||
return self._wal_paused
|
||||
|
||||
# -CHILD OBJECTS #######################################################
|
||||
|
@ -98,15 +98,17 @@ class Server:
|
|||
# METHODS ##############################################################
|
||||
|
||||
# IMPLEMENTATION DETAILS ###############################################
|
||||
def _fetch_recovery_state(self) -> None:
|
||||
recovery_check_sql = utils.templating.render_template(
|
||||
utils.templating.get_template_path(TEMPLATE_ROOT, 'check_recovery.sql', self._conn.version)
|
||||
)
|
||||
|
||||
cols, rows = utils.querying.execute_dict(self._conn, recovery_check_sql)
|
||||
if len(rows) > 0:
|
||||
self._in_recovery = rows[0]['inrecovery']
|
||||
self._wal_paused = rows[0]['isreplaypaused']
|
||||
else:
|
||||
self._in_recovery = None
|
||||
self._wal_paused = None
|
||||
# Commenting out until support for extended properties is added
|
||||
# See https://github.com/Microsoft/carbon/issues/1342
|
||||
# def _fetch_recovery_state(self) -> None:
|
||||
# recovery_check_sql = utils.templating.render_template(
|
||||
# utils.templating.get_template_path(TEMPLATE_ROOT, 'check_recovery.sql', self._conn.version)
|
||||
# )
|
||||
#
|
||||
# cols, rows = self._conn.execute_dict(recovery_check_sql)
|
||||
# if len(rows) > 0:
|
||||
# self._in_recovery = rows[0]['inrecovery']
|
||||
# self._wal_paused = rows[0]['isreplaypaused']
|
||||
# else:
|
||||
# self._in_recovery = None
|
||||
# self._wal_paused = None
|
||||
|
|
|
@ -7,19 +7,20 @@ from typing import List
|
|||
|
||||
from pgsmo.objects.column.column import Column
|
||||
import pgsmo.objects.node_object as node
|
||||
import pgsmo.utils as utils
|
||||
import pgsmo.utils.querying as querying
|
||||
import pgsmo.utils.templating as templating
|
||||
|
||||
|
||||
TEMPLATE_ROOT = utils.templating.get_template_root(__file__, 'templates')
|
||||
TEMPLATE_ROOT = templating.get_template_root(__file__, 'templates')
|
||||
|
||||
|
||||
class Table(node.NodeObject):
|
||||
@classmethod
|
||||
def get_nodes_for_parent(cls, conn: utils.querying.ConnectionWrapper, schema_id: int) -> List['Table']:
|
||||
def get_nodes_for_parent(cls, conn: querying.ServerConnection, schema_id: int) -> List['Table']:
|
||||
return node.get_nodes(conn, TEMPLATE_ROOT, cls._from_node_query, scid=schema_id)
|
||||
|
||||
@classmethod
|
||||
def _from_node_query(cls, conn: utils.querying.ConnectionWrapper, **kwargs) -> 'Table':
|
||||
def _from_node_query(cls, conn: querying.ServerConnection, **kwargs) -> 'Table':
|
||||
"""
|
||||
Creates a table instance from the results of a node query
|
||||
:param conn: The connection used to execute the node query
|
||||
|
@ -34,7 +35,7 @@ class Table(node.NodeObject):
|
|||
|
||||
return table
|
||||
|
||||
def __init__(self, conn: utils.querying.ConnectionWrapper, name: str):
|
||||
def __init__(self, conn: querying.ServerConnection, name: str):
|
||||
super(Table, self).__init__(conn, name)
|
||||
|
||||
# Declare child items
|
||||
|
|
|
@ -6,14 +6,15 @@
|
|||
from typing import List, Optional
|
||||
|
||||
from pgsmo.objects.node_object import NodeObject, get_nodes
|
||||
import pgsmo.utils as utils
|
||||
import pgsmo.utils.querying as querying
|
||||
import pgsmo.utils.templating as templating
|
||||
|
||||
TEMPLATE_ROOT = utils.templating.get_template_root(__file__, 'templates')
|
||||
TEMPLATE_ROOT = templating.get_template_root(__file__, 'templates')
|
||||
|
||||
|
||||
class Tablespace(NodeObject):
|
||||
@classmethod
|
||||
def get_nodes_for_parent(cls, conn: utils.querying.ConnectionWrapper) -> List['Tablespace']:
|
||||
def get_nodes_for_parent(cls, conn: querying.ServerConnection) -> List['Tablespace']:
|
||||
"""
|
||||
Creates a list of tablespaces that belong to the server. Intended to be called by Server class
|
||||
:param conn: Connection to a server to use to lookup the information
|
||||
|
@ -22,7 +23,7 @@ class Tablespace(NodeObject):
|
|||
return get_nodes(conn, TEMPLATE_ROOT, cls._from_node_query)
|
||||
|
||||
@classmethod
|
||||
def _from_node_query(cls, conn: utils.querying.ConnectionWrapper, **kwargs) -> 'Tablespace':
|
||||
def _from_node_query(cls, conn: querying.ServerConnection, **kwargs) -> 'Tablespace':
|
||||
"""
|
||||
Creates a tablespace from a row of a nodes query result
|
||||
:param conn: Connection to a server to use to lookup the information
|
||||
|
@ -36,7 +37,7 @@ class Tablespace(NodeObject):
|
|||
|
||||
return tablespace
|
||||
|
||||
def __init__(self, conn: utils.querying.ConnectionWrapper, name: str):
|
||||
def __init__(self, conn: querying.ServerConnection, name: str):
|
||||
"""
|
||||
Initializes internal state of a Role object
|
||||
:param conn: Connection that executed the role node query
|
||||
|
|
|
@ -8,19 +8,20 @@ from typing import List
|
|||
|
||||
from pgsmo.objects.column.column import Column
|
||||
import pgsmo.objects.node_object as node
|
||||
import pgsmo.utils as utils
|
||||
import pgsmo.utils.querying as querying
|
||||
import pgsmo.utils.templating as templating
|
||||
|
||||
TEMPLATE_ROOT = utils.templating.get_template_root(__file__, 'view_templates')
|
||||
TEMPLATE_ROOT = templating.get_template_root(__file__, 'view_templates')
|
||||
|
||||
|
||||
class View(node.NodeObject):
|
||||
@classmethod
|
||||
def get_nodes_for_parent(cls, conn: utils.querying.ConnectionWrapper, scid: int) -> List['View']:
|
||||
def get_nodes_for_parent(cls, conn: querying.ServerConnection, scid: int) -> List['View']:
|
||||
type_template_root = path.join(TEMPLATE_ROOT, conn.server_type)
|
||||
return node.get_nodes(conn, type_template_root, cls._from_node_query, scid=scid)
|
||||
|
||||
@classmethod
|
||||
def _from_node_query(cls, conn: utils.querying.ConnectionWrapper, **kwargs) -> 'View':
|
||||
def _from_node_query(cls, conn: querying.ServerConnection, **kwargs) -> 'View':
|
||||
"""
|
||||
Creates a view object from the results of a node query
|
||||
:param conn: Connection used to execute the nodes query
|
||||
|
@ -35,7 +36,7 @@ class View(node.NodeObject):
|
|||
|
||||
return view
|
||||
|
||||
def __init__(self, conn: utils.querying.ConnectionWrapper, name: str):
|
||||
def __init__(self, conn: querying.ServerConnection, name: str):
|
||||
super(View, self).__init__(conn, name)
|
||||
|
||||
# Declare child items
|
||||
|
|
|
@ -8,8 +8,14 @@ from typing import List, Mapping, Tuple
|
|||
from psycopg2.extensions import Column, connection, cursor # noqa
|
||||
|
||||
|
||||
class ConnectionWrapper:
|
||||
class ServerConnection:
|
||||
"""Wrapper for a psycopg2 connection that makes various properties easier to access"""
|
||||
|
||||
def __init__(self, conn: connection):
|
||||
"""
|
||||
Creates a new connection wrapper. Parses version string
|
||||
:param conn: PsycoPG2 connection object
|
||||
"""
|
||||
self._conn = conn
|
||||
self._dsn_parameters = conn.get_dsn_parameters()
|
||||
|
||||
|
@ -24,40 +30,35 @@ class ConnectionWrapper:
|
|||
# PROPERTIES ###########################################################
|
||||
@property
|
||||
def connection(self) -> connection:
|
||||
"""The psycopg2 connection that this object wraps"""
|
||||
return self._conn
|
||||
|
||||
@property
|
||||
def dsn_parameters(self) -> Mapping[str, str]:
|
||||
"""DSN properties of the underlying connection"""
|
||||
return self._dsn_parameters
|
||||
|
||||
@property
|
||||
def server_type(self) -> str:
|
||||
"""Server type for distinguishing between standard PG and PG supersets"""
|
||||
return 'pg' # TODO: Determine if a server is PPAS or PG
|
||||
|
||||
@property
|
||||
def version(self) -> Tuple[int, int, int]:
|
||||
"""Tuple that splits version string into sensible values"""
|
||||
return self._version
|
||||
|
||||
|
||||
def execute_2d_array(conn: ConnectionWrapper, query: str, params=None) -> Tuple[List[Column], List[list]]:
|
||||
cur: cursor = conn.connection.cursor()
|
||||
|
||||
try:
|
||||
cur.execute(query, params)
|
||||
|
||||
cols: List[Column] = cur.description
|
||||
rows: List[list] = []
|
||||
if cur.rowcount > 0:
|
||||
for row in cur:
|
||||
rows.append(row)
|
||||
|
||||
return cols, rows
|
||||
finally:
|
||||
cur.close()
|
||||
|
||||
|
||||
def execute_dict(conn: ConnectionWrapper, query: str, params=None) -> Tuple[List[Column], List[dict]]:
|
||||
cur: cursor = conn.connection.cursor()
|
||||
# METHODS ##############################################################
|
||||
def execute_dict(self, query: str, params=None) -> Tuple[List[Column], List[dict]]:
|
||||
"""
|
||||
Executes a query and returns the results as an ordered list of dictionaries that map column
|
||||
name to value. Columns are returned, as well.
|
||||
:param conn: The connection to use to execute the query
|
||||
:param query: The text of the query to execute
|
||||
:param params: Optional parameters to inject into the query
|
||||
:return: A list of column objects and a list of rows, which are formatted as dicts.
|
||||
"""
|
||||
cur: cursor = self._conn.cursor()
|
||||
|
||||
try:
|
||||
cur.execute(query, params)
|
||||
|
|
|
@ -12,7 +12,7 @@ from psycopg2.extensions import adapt
|
|||
|
||||
TEMPLATE_ENVIRONMENTS: Dict[str, Environment] = {}
|
||||
TEMPLATE_FOLDER_REGEX = re.compile('(\d+)\.(\d+)(?:_(\w+))?$')
|
||||
TEMPLATE_NON_DISCOVERED_NAMES: List[str] = ['macros']
|
||||
TEMPLATE_SKIPPED_FOLDERS: List[str] = ['macros']
|
||||
|
||||
|
||||
def get_template_root(file_path: str, template_directory: str) -> str:
|
||||
|
@ -29,7 +29,7 @@ def get_template_path(template_root: str, template_name: str, server_version: Tu
|
|||
"""
|
||||
# Step 1) Get the list of folders in the template folder that contains the
|
||||
# Step 1.1) Get the list of folders in the template root folder
|
||||
all_folders: List[str] = [x[0] for x in os.walk(template_root)]
|
||||
all_folders: List[str] = [os.path.normpath(x[0]) for x in os.walk(template_root)]
|
||||
|
||||
# Step 1.2) Filter out the folders that don't contain the target template
|
||||
containing_folders: List[str] = [x for x in all_folders if template_name in next(os.walk(x))[2]]
|
||||
|
@ -40,16 +40,17 @@ def get_template_path(template_root: str, template_name: str, server_version: Tu
|
|||
# Step 2) Iterate over the list of directories and check if the server version fits the bill
|
||||
for folder in containing_folders:
|
||||
# Skip over non-included folders
|
||||
if os.path.dirname(folder) in TEMPLATE_NON_DISCOVERED_NAMES:
|
||||
if os.path.basename(folder) in TEMPLATE_SKIPPED_FOLDERS:
|
||||
continue
|
||||
|
||||
# If we are at the default, we are at the end of the list, so it is the only valid match
|
||||
if folder.endswith('/+default'):
|
||||
if folder.endswith(os.sep + '+default'):
|
||||
return os.path.join(folder, template_name)
|
||||
|
||||
# Process the folder name with the regex
|
||||
match = TEMPLATE_FOLDER_REGEX.search(folder)
|
||||
if not match:
|
||||
# This indicates a serious bug that shouldn't occur in production code, so this needn't be localized.
|
||||
raise ValueError(f'Templates folder {template_root} contains improperly formatted folder name {folder}')
|
||||
captures = match.groups()
|
||||
major = int(captures[0])
|
||||
|
@ -67,6 +68,7 @@ def get_template_path(template_root: str, template_name: str, server_version: Tu
|
|||
# TODO: Modifier is minus
|
||||
|
||||
# If we make it to here, the template doesn't exist.
|
||||
# This indicates a serious bug that shouldn't occur in production code, so this doesn't need to be localized.
|
||||
raise ValueError(f'Template folder {template_root} does not contain {template_name}')
|
||||
|
||||
|
||||
|
@ -75,6 +77,7 @@ def render_template(template_path: str, **context) -> str:
|
|||
Renders a template from the template folder with the given context.
|
||||
:param template_path: the path to the template to be rendered
|
||||
:param context: the variables that should be available in the context of the template.
|
||||
:return: The template rendered with the provided context
|
||||
"""
|
||||
path, filename = os.path.split(template_path)
|
||||
if path not in TEMPLATE_ENVIRONMENTS:
|
||||
|
@ -85,13 +88,13 @@ def render_template(template_path: str, **context) -> str:
|
|||
|
||||
# Create the environment and add the basic filters
|
||||
new_env: Environment = Environment(loader=loader)
|
||||
new_env.filters['qtLiteral'] = qtLiteral
|
||||
new_env.filters['qtIdent'] = qtIdent
|
||||
new_env.filters['qtTypeIdent'] = qtTypeIdent
|
||||
new_env.filters['qtLiteral'] = qt_literal
|
||||
new_env.filters['qtIdent'] = qt_ident
|
||||
new_env.filters['qtTypeIdent'] = qt_type_ident
|
||||
|
||||
TEMPLATE_ENVIRONMENTS[template_path] = new_env
|
||||
TEMPLATE_ENVIRONMENTS[path] = new_env
|
||||
|
||||
env = TEMPLATE_ENVIRONMENTS[template_path]
|
||||
env = TEMPLATE_ENVIRONMENTS[path]
|
||||
return env.get_template(filename).render(context)
|
||||
|
||||
|
||||
|
@ -101,6 +104,7 @@ def render_template_string(source, **context):
|
|||
autoescaped.
|
||||
:param source: the source code of the template to be rendered
|
||||
:param context: the variables that should be available in the context of the template.
|
||||
:return: The template rendered with the provided context
|
||||
"""
|
||||
template = Template(source)
|
||||
return template.render(context)
|
||||
|
@ -115,7 +119,7 @@ def render_template_string(source, **context):
|
|||
##########################################################################
|
||||
|
||||
|
||||
def qtLiteral(value):
|
||||
def qt_literal(value):
|
||||
adapted = adapt(value)
|
||||
|
||||
# Not all adapted objects have encoding
|
||||
|
@ -133,7 +137,7 @@ def qtLiteral(value):
|
|||
return res
|
||||
|
||||
|
||||
def qtTypeIdent(conn, *args):
|
||||
def qt_type_ident(conn, *args):
|
||||
# We're not using the conn object at the moment, but - we will modify the
|
||||
# logic to use the server version specific keywords later.
|
||||
res = None
|
||||
|
@ -144,7 +148,7 @@ def qtTypeIdent(conn, *args):
|
|||
continue
|
||||
value = val
|
||||
|
||||
if needsQuoting(val, True):
|
||||
if needs_quoting(val, True):
|
||||
value = value.replace("\"", "\"\"")
|
||||
value = "\"" + value + "\""
|
||||
|
||||
|
@ -153,7 +157,7 @@ def qtTypeIdent(conn, *args):
|
|||
return res
|
||||
|
||||
|
||||
def qtIdent(conn, *args):
|
||||
def qt_ident(conn, *args):
|
||||
# We're not using the conn object at the moment, but - we will modify the
|
||||
# logic to use the server version specific keywords later.
|
||||
res = None
|
||||
|
@ -161,14 +165,14 @@ def qtIdent(conn, *args):
|
|||
|
||||
for val in args:
|
||||
if type(val) == list:
|
||||
return map(lambda w: qtIdent(conn, w), val)
|
||||
return map(lambda w: qt_ident(conn, w), val)
|
||||
|
||||
if len(val) == 0:
|
||||
continue
|
||||
|
||||
value = val
|
||||
|
||||
if needsQuoting(val, False):
|
||||
if needs_quoting(val, False):
|
||||
value = value.replace("\"", "\"\"")
|
||||
value = "\"" + value + "\""
|
||||
|
||||
|
@ -177,18 +181,18 @@ def qtIdent(conn, *args):
|
|||
return res
|
||||
|
||||
|
||||
def needsQuoting(key, forTypes):
|
||||
def needs_quoting(key, for_types):
|
||||
value = key
|
||||
valNoArray = value
|
||||
val_no_array = value
|
||||
|
||||
# check if the string is number or not
|
||||
if isinstance(value, int):
|
||||
return True
|
||||
# certain types should not be quoted even though it contains a space. Evilness.
|
||||
elif forTypes and value[-2:] == u"[]":
|
||||
valNoArray = value[:-2]
|
||||
elif for_types and value[-2:] == u"[]":
|
||||
val_no_array = value[:-2]
|
||||
|
||||
if forTypes and valNoArray.lower() in [
|
||||
if for_types and val_no_array.lower() in [
|
||||
u'bit varying',
|
||||
u'"char"',
|
||||
u'character varying',
|
||||
|
@ -203,21 +207,21 @@ def needsQuoting(key, forTypes):
|
|||
return False
|
||||
|
||||
# If already quoted?, If yes then do not quote again
|
||||
if forTypes and valNoArray:
|
||||
if valNoArray.startswith('"') \
|
||||
or valNoArray.endswith('"'):
|
||||
if for_types and val_no_array:
|
||||
if val_no_array.startswith('"') \
|
||||
or val_no_array.endswith('"'):
|
||||
return False
|
||||
|
||||
if u'0' <= valNoArray[0] <= u'9':
|
||||
if u'0' <= val_no_array[0] <= u'9':
|
||||
return True
|
||||
|
||||
for c in valNoArray:
|
||||
for c in val_no_array:
|
||||
if (not (u'a' <= c <= u'z') and c != u'_' and
|
||||
not (u'0' <= c <= u'9')):
|
||||
return True
|
||||
|
||||
# check string is keywaord or not
|
||||
category = ScanKeywordExtraLookup(value)
|
||||
category = scan_keyword_extra_lookup(value)
|
||||
|
||||
if category is None:
|
||||
return False
|
||||
|
@ -227,18 +231,21 @@ def needsQuoting(key, forTypes):
|
|||
return False
|
||||
|
||||
# COL_NAME_KEYWORD
|
||||
if forTypes and category == 1:
|
||||
if for_types and category == 1:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def ScanKeywordExtraLookup(key):
|
||||
# UNRESERVED_KEYWORD 0
|
||||
# COL_NAME_KEYWORD 1
|
||||
# TYPE_FUNC_NAME_KEYWORD 2
|
||||
# RESERVED_KEYWORD 3
|
||||
extraKeywords = {
|
||||
def scan_keyword_extra_lookup(key):
|
||||
return (key in _EXTRA_KEYWORDS and _EXTRA_KEYWORDS[key]) or _KEYWORD_DICT.get(key)
|
||||
|
||||
|
||||
# UNRESERVED_KEYWORD 0
|
||||
# COL_NAME_KEYWORD 1
|
||||
# TYPE_FUNC_NAME_KEYWORD 2
|
||||
# RESERVED_KEYWORD 3
|
||||
_EXTRA_KEYWORDS = {
|
||||
'connect': 3,
|
||||
'convert': 3,
|
||||
'distributed': 0,
|
||||
|
@ -260,14 +267,10 @@ def ScanKeywordExtraLookup(key):
|
|||
'tinyint': 3,
|
||||
'tinytext': 3,
|
||||
'varchar2': 3
|
||||
}
|
||||
}
|
||||
|
||||
return (key in extraKeywords and extraKeywords[key]) or ScanKeyword(key)
|
||||
|
||||
|
||||
def ScanKeyword(key):
|
||||
# ScanKeyword function for PostgreSQL 9.5rc1
|
||||
keywordDict = {
|
||||
# ScanKeyword function for PostgreSQL 9.5rc1
|
||||
_KEYWORD_DICT = {
|
||||
"abort": 0, "absolute": 0, "access": 0, "action": 0, "add": 0, "admin": 0, "after": 0, "aggregate": 0, "all": 3,
|
||||
"also": 0, "alter": 0, "always": 0, "analyze": 3, "and": 3, "any": 3, "array": 3, "as": 3, "asc": 3,
|
||||
"assertion": 0, "assignment": 0, "asymmetric": 3, "at": 0, "attribute": 0, "authorization": 2, "backward": 0,
|
||||
|
@ -321,5 +324,4 @@ def ScanKeyword(key):
|
|||
"with": 3, "within": 0, "without": 0, "work": 0, "wrapper": 0, "write": 0, "xml": 0, "xmlattributes": 1,
|
||||
"xmlconcat": 1, "xmlelement": 1, "xmlexists": 1, "xmlforest": 1, "xmlparse": 1, "xmlpi": 1, "xmlroot": 1,
|
||||
"xmlserialize": 1, "year": 0, "yes": 0, "zone": 0
|
||||
}
|
||||
return keywordDict.get(key)
|
||||
}
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
{{foo}}
|
|
@ -125,23 +125,22 @@ class TestNodeObject(unittest.TestCase):
|
|||
|
||||
def test_get_nodes(self):
|
||||
# Setup:
|
||||
# ... Create a mockup of a connection wrapper
|
||||
# ... Create a mockup of a server connection with a mock executor
|
||||
version = (1, 1, 1)
|
||||
|
||||
class MockConn:
|
||||
def __init__(self):
|
||||
self.version = version
|
||||
|
||||
mock_objs = [{'name': 'abc', 'oid': 123}, {'name': 'def', 'oid': 456}]
|
||||
mock_executor = mock.MagicMock(return_value=([{}, {}], mock_objs))
|
||||
mock_conn = MockConn()
|
||||
mock_conn.execute_dict = mock_executor
|
||||
|
||||
# ... Create a mock template renderer
|
||||
mock_render = mock.MagicMock(return_value="SQL")
|
||||
mock_template_path = mock.MagicMock(return_value="path")
|
||||
|
||||
# ... Create a mock query executor
|
||||
mock_objs = [{'name': 'abc', 'oid': 123}, {'name': 'def', 'oid': 456}]
|
||||
mock_executor = mock.MagicMock(return_value=([{}, {}], mock_objs))
|
||||
|
||||
# ... Create a mock generator
|
||||
mock_output = {}
|
||||
mock_generator = mock.MagicMock(return_value=mock_output)
|
||||
|
@ -149,7 +148,6 @@ class TestNodeObject(unittest.TestCase):
|
|||
# ... Do the patching
|
||||
with mock.patch('pgsmo.objects.node_object.templating.render_template', mock_render, create=True):
|
||||
with mock.patch('pgsmo.objects.node_object.templating.get_template_path', mock_template_path, create=True):
|
||||
with mock.patch('pgsmo.objects.node_object.querying.execute_dict', mock_executor, create=True):
|
||||
# If: I ask for a collection of nodes
|
||||
kwargs = {'arg1': 'something'}
|
||||
nodes = node.get_nodes(mock_conn, 'root', mock_generator, **kwargs)
|
||||
|
@ -162,7 +160,7 @@ class TestNodeObject(unittest.TestCase):
|
|||
mock_render.assert_called_once_with('path', **kwargs)
|
||||
|
||||
# ... A query should have been executed
|
||||
mock_executor.assert_called_once_with(mock_conn, 'SQL')
|
||||
mock_executor.assert_called_once_with('SQL')
|
||||
|
||||
# ... The generator should have been called twice with different object props
|
||||
mock_generator.assert_any_call(mock_conn, **mock_objs[0])
|
|
@ -0,0 +1,65 @@
|
|||
# --------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for license information.
|
||||
# --------------------------------------------------------------------------------------------
|
||||
|
||||
import unittest
|
||||
|
||||
from pgsmo.objects.column.column import Column
|
||||
from pgsmo.utils.querying import ServerConnection
|
||||
import tests.pgsmo_tests.utils as utils
|
||||
|
||||
COLUMN_ROW = {
|
||||
'name': 'abc',
|
||||
'datatype': 'character',
|
||||
'oid': 123,
|
||||
'has_default_val': True,
|
||||
'not_null': True
|
||||
}
|
||||
|
||||
|
||||
class TestColumn(unittest.TestCase):
|
||||
# CONSTRUCTION TESTS ###################################################
|
||||
def test_init(self):
|
||||
props = [
|
||||
'has_default_value', '_has_default_value',
|
||||
'not_null', '_not_null'
|
||||
]
|
||||
colls = []
|
||||
name = 'column'
|
||||
datatype = 'character'
|
||||
mock_conn = ServerConnection(utils.MockConnection(None))
|
||||
obj = Column(mock_conn, name, datatype)
|
||||
utils.validate_init(
|
||||
Column, name, mock_conn, obj, props, colls,
|
||||
lambda obj: self._validate_init(obj, datatype)
|
||||
)
|
||||
|
||||
def test_from_node_query(self):
|
||||
utils.from_node_query_base(Column, COLUMN_ROW, self._validate_column)
|
||||
|
||||
def test_get_nodes_for_parent(self):
|
||||
# Use the test helper for this method
|
||||
get_nodes_for_parent = (lambda conn: Column.get_nodes_for_parent(conn, 10))
|
||||
utils.get_nodes_for_parent_base(Column, COLUMN_ROW, get_nodes_for_parent, self._validate_column)
|
||||
|
||||
# IMPLEMENTATION DETAILS ###############################################
|
||||
def _validate_init(self, obj: Column, datatype: str):
|
||||
self.assertEqual(obj._datatype, datatype)
|
||||
self.assertEqual(obj.datatype, datatype)
|
||||
|
||||
def _validate_column(self, obj: Column, mock_conn: ServerConnection):
|
||||
# NodeObject basic properties
|
||||
self.assertIs(obj._conn, mock_conn)
|
||||
self.assertEqual(obj._oid, COLUMN_ROW['oid'])
|
||||
self.assertEqual(obj.oid, COLUMN_ROW['oid'])
|
||||
self.assertEqual(obj._name, COLUMN_ROW['name'])
|
||||
self.assertEqual(obj.name, COLUMN_ROW['name'])
|
||||
|
||||
# Column-specific basic properties
|
||||
self.assertEqual(obj._datatype, COLUMN_ROW['datatype'])
|
||||
self.assertEqual(obj.datatype, COLUMN_ROW['datatype'])
|
||||
self.assertEqual(obj._has_default_value, COLUMN_ROW['has_default_val'])
|
||||
self.assertEqual(obj.has_default_value, COLUMN_ROW['has_default_val'])
|
||||
self.assertEqual(obj._not_null, COLUMN_ROW['not_null'])
|
||||
self.assertEqual(obj.not_null, COLUMN_ROW['not_null'])
|
|
@ -0,0 +1,101 @@
|
|||
# --------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for license information.
|
||||
# --------------------------------------------------------------------------------------------
|
||||
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
|
||||
from pgsmo.objects.database.database import Database
|
||||
from pgsmo.utils.querying import ServerConnection
|
||||
import tests.pgsmo_tests.utils as utils
|
||||
|
||||
DATABASE_ROW = {
|
||||
'name': 'dbname',
|
||||
'did': 123,
|
||||
'spcname': 'primary',
|
||||
'datallowconn': True,
|
||||
'cancreate': True,
|
||||
'owner': 10
|
||||
}
|
||||
|
||||
|
||||
class TestDatabase(unittest.TestCase):
|
||||
# CONSTRUCTION TESTS ###################################################
|
||||
def test_init_connected(self):
|
||||
props = [
|
||||
'_tablespace', 'tablespace',
|
||||
'_allow_conn', 'allow_conn',
|
||||
'_can_create', 'can_create',
|
||||
'_owner_oid'
|
||||
]
|
||||
colls = ['_schemas', 'schemas'] # When connected, these are actually defined
|
||||
name = 'dbname'
|
||||
mock_conn = ServerConnection(utils.MockConnection(None, name=name))
|
||||
db = Database(mock_conn, name)
|
||||
utils.validate_init(
|
||||
Database, name, mock_conn, db, props, colls,
|
||||
lambda obj: self._init_validation(obj, True)
|
||||
)
|
||||
|
||||
def test_init_disconnected(self):
|
||||
props = [
|
||||
'_tablespace', 'tablespace',
|
||||
'_allow_conn', 'allow_conn',
|
||||
'_can_create', 'can_create',
|
||||
'_owner_oid',
|
||||
'_schemas', 'schemas' # When not connected we want these to be set to None
|
||||
]
|
||||
colls = []
|
||||
name = 'dbname'
|
||||
mock_conn = ServerConnection(utils.MockConnection(None, name='notconnected'))
|
||||
db = Database(mock_conn, name)
|
||||
utils.validate_init(
|
||||
Database, name, mock_conn, db, props, colls,
|
||||
lambda obj: self._init_validation(obj, False)
|
||||
)
|
||||
|
||||
def test_from_node_query(self):
|
||||
utils.from_node_query_base(Database, DATABASE_ROW, self._validate_database)
|
||||
|
||||
def test_get_nodes_for_parent(self):
|
||||
# Use the test helper for this method
|
||||
utils.get_nodes_for_parent_base(Database, DATABASE_ROW, Database.get_nodes_for_parent, self._validate_database)
|
||||
|
||||
# METHOD TESTS #########################################################
|
||||
def test_refresh(self):
|
||||
# Setup: Create a database object and overwrite the reset method of the child objects
|
||||
db = Database(ServerConnection(utils.MockConnection(None, name=DATABASE_ROW['name'])), DATABASE_ROW['name'])
|
||||
db._schemas.reset = mock.MagicMock()
|
||||
|
||||
# If: I refresh the database
|
||||
db.refresh()
|
||||
|
||||
# Then: The mocks should have been called
|
||||
db._schemas.reset.assert_called_once()
|
||||
|
||||
# IMPLEMENTATION DETAILS ###############################################
|
||||
def _init_validation(self, obj: Database, is_connected: bool):
|
||||
self.assertEqual(obj._is_connected, is_connected)
|
||||
|
||||
def _validate_database(self, db: Database, mock_conn: ServerConnection):
|
||||
# NodeObject basic properties
|
||||
self.assertIs(db._conn, mock_conn)
|
||||
self.assertEqual(db._oid, DATABASE_ROW['did'])
|
||||
self.assertEqual(db.oid, DATABASE_ROW['did'])
|
||||
self.assertEqual(db._name, DATABASE_ROW['name'])
|
||||
self.assertEqual(db.name, DATABASE_ROW['name'])
|
||||
|
||||
# Database-specific basic properties
|
||||
self.assertEqual(db._tablespace, DATABASE_ROW['spcname'])
|
||||
self.assertEqual(db.tablespace, DATABASE_ROW['spcname'])
|
||||
self.assertEqual(db._allow_conn, DATABASE_ROW['datallowconn'])
|
||||
self.assertEqual(db.allow_conn, DATABASE_ROW['datallowconn'])
|
||||
self.assertEqual(db._can_create, DATABASE_ROW['cancreate'])
|
||||
self.assertEqual(db.can_create, DATABASE_ROW['cancreate'])
|
||||
self.assertEqual(db._owner_oid, DATABASE_ROW['owner'])
|
||||
|
||||
# Child objects
|
||||
self.assertFalse(db._is_connected)
|
||||
self.assertIsNone(db._schemas)
|
||||
self.assertIsNone(db.schemas)
|
|
@ -0,0 +1,48 @@
|
|||
# --------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for license information.
|
||||
# --------------------------------------------------------------------------------------------
|
||||
|
||||
import unittest
|
||||
|
||||
from pgsmo.objects.role.role import Role
|
||||
from pgsmo.utils.querying import ServerConnection
|
||||
import tests.pgsmo_tests.utils as utils
|
||||
|
||||
|
||||
ROLE_ROW = {
|
||||
'rolname': 'role',
|
||||
'oid': 123,
|
||||
'rolcanlogin': True,
|
||||
'rolsuper': True
|
||||
}
|
||||
|
||||
|
||||
class TestRole(unittest.TestCase):
|
||||
# CONSTRUCTION TESTS ###################################################
|
||||
def test_init(self):
|
||||
props = ['_can_login', 'can_login', '_super', 'super']
|
||||
colls = []
|
||||
utils.init_base(Role, props, colls)
|
||||
|
||||
def test_from_node_query(self):
|
||||
utils.from_node_query_base(Role, ROLE_ROW, self._validate_role)
|
||||
|
||||
def test_get_nodes_for_parent(self):
|
||||
# Use the test helper
|
||||
utils.get_nodes_for_parent_base(Role, ROLE_ROW, Role.get_nodes_for_parent, self._validate_role)
|
||||
|
||||
# IMPLEMENTATION DETAILS ###############################################
|
||||
def _validate_role(self, role: Role, mock_conn: ServerConnection):
|
||||
# NodeObject basic properties
|
||||
self.assertIs(role._conn, mock_conn)
|
||||
self.assertEqual(role._oid, ROLE_ROW['oid'])
|
||||
self.assertEqual(role.oid, ROLE_ROW['oid'])
|
||||
self.assertEqual(role._name, ROLE_ROW['rolname'])
|
||||
self.assertEqual(role.name, ROLE_ROW['rolname'])
|
||||
|
||||
# Role-specific basic properties
|
||||
self.assertEqual(role._can_login, ROLE_ROW['rolcanlogin'])
|
||||
self.assertEqual(role.can_login, ROLE_ROW['rolcanlogin'])
|
||||
self.assertEqual(role._super, ROLE_ROW['rolsuper'])
|
||||
self.assertEqual(role.super, ROLE_ROW['rolsuper'])
|
|
@ -0,0 +1,70 @@
|
|||
# --------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for license information.
|
||||
# --------------------------------------------------------------------------------------------
|
||||
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
|
||||
from pgsmo.objects.node_object import NodeCollection
|
||||
from pgsmo.objects.schema.schema import Schema
|
||||
from pgsmo.utils.querying import ServerConnection
|
||||
import tests.pgsmo_tests.utils as utils
|
||||
|
||||
|
||||
SCHEMA_ROW = {
|
||||
'name': 'schema',
|
||||
'oid': 123,
|
||||
'can_create': True,
|
||||
'has_usage': True
|
||||
}
|
||||
|
||||
|
||||
class TestSchema(unittest.TestCase):
|
||||
# CONSTRUCTION TESTS ###################################################
|
||||
def test_init(self):
|
||||
props = ['_can_create', 'can_create', '_has_usage', 'has_usage']
|
||||
collections = ['_tables', 'tables', '_views', 'views']
|
||||
utils.init_base(Schema, props, collections)
|
||||
|
||||
def test_from_node_query(self):
|
||||
utils.from_node_query_base(Schema, SCHEMA_ROW, self._validate_schema)
|
||||
|
||||
def test_from_nodes_for_parent(self):
|
||||
# Use the test helper for this method
|
||||
utils.get_nodes_for_parent_base(Schema, SCHEMA_ROW, Schema.get_nodes_for_parent, self._validate_schema)
|
||||
|
||||
# METHOD TESTS #########################################################
|
||||
def test_refresh(self):
|
||||
# Setup: Create a schema object and mock up the node collection reset methods
|
||||
mock_conn = ServerConnection(utils.MockConnection(None))
|
||||
schema = Schema._from_node_query(mock_conn, **SCHEMA_ROW)
|
||||
schema._tables.reset = mock.MagicMock()
|
||||
schema._views.reset = mock.MagicMock()
|
||||
|
||||
# If: I refresh a schema object
|
||||
schema.refresh()
|
||||
|
||||
# Then: The child object node collections should have been reset
|
||||
schema._tables.reset.assert_called_once()
|
||||
schema._views.reset.assert_called_once()
|
||||
|
||||
def _validate_schema(self, schema: Schema, mock_conn: ServerConnection):
|
||||
# NodeObject basic properties
|
||||
self.assertIs(schema._conn, mock_conn)
|
||||
self.assertEqual(schema._oid, SCHEMA_ROW['oid'])
|
||||
self.assertEqual(schema.oid, SCHEMA_ROW['oid'])
|
||||
self.assertEqual(schema._name, SCHEMA_ROW['name'])
|
||||
self.assertEqual(schema.name, SCHEMA_ROW['name'])
|
||||
|
||||
# Schema-specific basic properties
|
||||
self.assertEqual(schema._can_create, SCHEMA_ROW['can_create'])
|
||||
self.assertEqual(schema.can_create, SCHEMA_ROW['can_create'])
|
||||
self.assertEqual(schema._has_usage, SCHEMA_ROW['has_usage'])
|
||||
self.assertEqual(schema.has_usage, SCHEMA_ROW['has_usage'])
|
||||
|
||||
# Child objects
|
||||
self.assertIsInstance(schema._tables, NodeCollection)
|
||||
self.assertIs(schema.tables, schema._tables)
|
||||
self.assertIsInstance(schema._views, NodeCollection)
|
||||
self.assertIs(schema.views, schema._views)
|
|
@ -0,0 +1,48 @@
|
|||
# --------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for license information.
|
||||
# --------------------------------------------------------------------------------------------
|
||||
|
||||
import unittest
|
||||
|
||||
from pgsmo.objects.node_object import NodeCollection
|
||||
from pgsmo.objects.server.server import Server
|
||||
from pgsmo.utils.querying import ServerConnection
|
||||
import tests.pgsmo_tests.utils as utils
|
||||
|
||||
|
||||
class TestServer(unittest.TestCase):
|
||||
def test_init(self):
|
||||
# If: I construct a new server object
|
||||
host = 'host'
|
||||
port = '1234'
|
||||
dbname = 'dbname'
|
||||
mock_conn = utils.MockConnection(None, name=dbname, host=host, port=port)
|
||||
server = Server(mock_conn)
|
||||
|
||||
# Then:
|
||||
# ... The assigned properties should be assigned
|
||||
self.assertIsInstance(server._conn, ServerConnection)
|
||||
self.assertIsInstance(server.connection, ServerConnection)
|
||||
self.assertIs(server.connection.connection, mock_conn)
|
||||
self.assertEqual(server._host, host)
|
||||
self.assertEqual(server.host, host)
|
||||
self.assertEqual(server._port, int(port))
|
||||
self.assertEqual(server.port, int(port))
|
||||
self.assertEqual(server._maintenance_db, dbname)
|
||||
self.assertEqual(server.maintenance_db, dbname)
|
||||
self.assertTupleEqual(server.version, server._conn.version)
|
||||
|
||||
# ... The optional properties should be assigned to None
|
||||
self.assertIsNone(server._in_recovery)
|
||||
self.assertIsNone(server.in_recovery)
|
||||
self.assertIsNone(server._wal_paused)
|
||||
self.assertIsNone(server.wal_paused)
|
||||
|
||||
# ... The child object collections should be assigned to NodeCollections
|
||||
self.assertIsInstance(server._databases, NodeCollection)
|
||||
self.assertIs(server.databases, server._databases)
|
||||
self.assertIsInstance(server._roles, NodeCollection)
|
||||
self.assertIs(server.roles, server._roles)
|
||||
self.assertIsInstance(server._tablespaces, NodeCollection)
|
||||
self.assertIs(server.tablespaces, server._tablespaces)
|
|
@ -0,0 +1,62 @@
|
|||
# --------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for license information.
|
||||
# --------------------------------------------------------------------------------------------
|
||||
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
|
||||
from pgsmo.objects.node_object import NodeCollection
|
||||
from pgsmo.objects.table.table import Table
|
||||
from pgsmo.utils.querying import ServerConnection
|
||||
import tests.pgsmo_tests.utils as utils
|
||||
|
||||
TABLE_ROW = {
|
||||
'name': 'tablename',
|
||||
'oid': 123
|
||||
}
|
||||
|
||||
|
||||
class TestTable(unittest.TestCase):
|
||||
# CONSTRUCTION TESTS ###################################################
|
||||
def test_init(self):
|
||||
props = []
|
||||
colls = ['_columns', 'columns']
|
||||
utils.init_base(Table, props, colls)
|
||||
|
||||
def test_from_node_query(self):
|
||||
utils.from_node_query_base(Table, TABLE_ROW, self._validate_table)
|
||||
|
||||
def test_from_nodes_for_parent(self):
|
||||
utils.get_nodes_for_parent_base(
|
||||
Table,
|
||||
TABLE_ROW,
|
||||
lambda conn: Table.get_nodes_for_parent(conn, 0),
|
||||
self._validate_table
|
||||
)
|
||||
|
||||
# METHOD TESTS #########################################################
|
||||
def test_refresh(self):
|
||||
# Setup: Create a table object and mock up the node collection reset method
|
||||
mock_conn = ServerConnection(utils.MockConnection(None))
|
||||
table = Table._from_node_query(mock_conn, **TABLE_ROW)
|
||||
table._columns.reset = mock.MagicMock()
|
||||
|
||||
# If: I refresh a table object
|
||||
table.refresh()
|
||||
|
||||
# Then: The child object node collections should have been reset
|
||||
table._columns.reset.assert_called_once()
|
||||
|
||||
# IMPLEMENTATION DETAILS ###############################################
|
||||
def _validate_table(self, table: Table, mock_conn: ServerConnection):
|
||||
# NodeObject basic properties
|
||||
self.assertIs(table._conn, mock_conn)
|
||||
self.assertEqual(table._oid, TABLE_ROW['oid'])
|
||||
self.assertEqual(table.oid, TABLE_ROW['oid'])
|
||||
self.assertEqual(table._name, TABLE_ROW['name'])
|
||||
self.assertEqual(table.name, TABLE_ROW['name'])
|
||||
|
||||
# Child objects
|
||||
self.assertIsInstance(table._columns, NodeCollection)
|
||||
self.assertIs(table.columns, table._columns)
|
|
@ -0,0 +1,47 @@
|
|||
# --------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for license information.
|
||||
# --------------------------------------------------------------------------------------------
|
||||
|
||||
import unittest
|
||||
|
||||
from pgsmo.objects.tablespace.tablespace import Tablespace
|
||||
from pgsmo.utils.querying import ServerConnection
|
||||
import tests.pgsmo_tests.utils as utils
|
||||
|
||||
NODE_ROW = {
|
||||
'name': 'test',
|
||||
'oid': 123,
|
||||
'owner': 10
|
||||
}
|
||||
|
||||
|
||||
class TestTablespace(unittest.TestCase):
|
||||
# CONSTRUCTION TESTS ###################################################
|
||||
def test_init(self):
|
||||
props = ['owner', '_owner']
|
||||
utils.init_base(Tablespace, props, [])
|
||||
|
||||
def test_from_node_query(self):
|
||||
utils.from_node_query_base(Tablespace, NODE_ROW, self._validate_tablespace)
|
||||
|
||||
def test_get_nodes_for_parent(self):
|
||||
utils.get_nodes_for_parent_base(
|
||||
Tablespace,
|
||||
NODE_ROW,
|
||||
Tablespace.get_nodes_for_parent,
|
||||
self._validate_tablespace
|
||||
)
|
||||
|
||||
# IMPLEMENTATION DETAILS ###############################################
|
||||
def _validate_tablespace(self, obj: Tablespace, mock_conn: ServerConnection):
|
||||
# NodeObject basic properties
|
||||
self.assertIs(obj._conn, mock_conn)
|
||||
self.assertEqual(obj._oid, NODE_ROW['oid'])
|
||||
self.assertEqual(obj.oid, NODE_ROW['oid'])
|
||||
self.assertEqual(obj._name, NODE_ROW['name'])
|
||||
self.assertEqual(obj.name, NODE_ROW['name'])
|
||||
|
||||
# Tablespace-specific properties
|
||||
self.assertEqual(obj._owner, NODE_ROW['owner'])
|
||||
self.assertEqual(obj.owner, NODE_ROW['owner'])
|
|
@ -0,0 +1,62 @@
|
|||
# --------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for license information.
|
||||
# --------------------------------------------------------------------------------------------
|
||||
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
|
||||
from pgsmo.objects.node_object import NodeCollection
|
||||
from pgsmo.objects.view.view import View
|
||||
from pgsmo.utils.querying import ServerConnection
|
||||
import tests.pgsmo_tests.utils as utils
|
||||
|
||||
NODE_ROW = {
|
||||
'name': 'viewname',
|
||||
'oid': 123
|
||||
}
|
||||
|
||||
|
||||
class TestTable(unittest.TestCase):
|
||||
# CONSTRUCTION TESTS ###################################################
|
||||
def test_init(self):
|
||||
props = []
|
||||
colls = ['_columns', 'columns']
|
||||
utils.init_base(View, props, colls)
|
||||
|
||||
def test_from_node_query(self):
|
||||
utils.from_node_query_base(View, NODE_ROW, self._validate_view)
|
||||
|
||||
def test_from_nodes_for_parent(self):
|
||||
utils.get_nodes_for_parent_base(
|
||||
View,
|
||||
NODE_ROW,
|
||||
lambda conn: View.get_nodes_for_parent(conn, 0),
|
||||
self._validate_view
|
||||
)
|
||||
|
||||
# METHOD TESTS #########################################################
|
||||
def test_refresh(self):
|
||||
# Setup: Create a table object and mock up the node collection reset method
|
||||
mock_conn = ServerConnection(utils.MockConnection(None))
|
||||
table = View._from_node_query(mock_conn, **NODE_ROW)
|
||||
table._columns.reset = mock.MagicMock()
|
||||
|
||||
# If: I refresh a table object
|
||||
table.refresh()
|
||||
|
||||
# Then: The child object node collections should have been reset
|
||||
table._columns.reset.assert_called_once()
|
||||
|
||||
# IMPLEMENTATION DETAILS ###############################################
|
||||
def _validate_view(self, view: View, mock_conn: ServerConnection):
|
||||
# NodeObject basic properties
|
||||
self.assertIs(view._conn, mock_conn)
|
||||
self.assertEqual(view._oid, NODE_ROW['oid'])
|
||||
self.assertEqual(view.oid, NODE_ROW['oid'])
|
||||
self.assertEqual(view._name, NODE_ROW['name'])
|
||||
self.assertEqual(view.name, NODE_ROW['name'])
|
||||
|
||||
# Child objects
|
||||
self.assertIsInstance(view._columns, NodeCollection)
|
||||
self.assertIs(view.columns, view._columns)
|
|
@ -0,0 +1,73 @@
|
|||
# --------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for license information.
|
||||
# --------------------------------------------------------------------------------------------
|
||||
|
||||
import unittest
|
||||
|
||||
import pgsmo.utils as pgsmo_utils
|
||||
import tests.pgsmo_tests.utils as utils
|
||||
|
||||
|
||||
class TestServerConnection(unittest.TestCase):
|
||||
def test_server_conn_init(self):
|
||||
# Setup: Create a mock connection with an 'interesting' version
|
||||
mock_conn = utils.MockConnection(None, version='100216')
|
||||
|
||||
# If: I initialize a server connection
|
||||
# noinspection PyTypeChecker
|
||||
server_conn = pgsmo_utils.querying.ServerConnection(mock_conn)
|
||||
|
||||
# Then: The properties should be properly set
|
||||
self.assertEqual(server_conn._conn, mock_conn)
|
||||
self.assertEqual(server_conn.connection, mock_conn)
|
||||
expected_dict = {'dbname': 'postgres', 'host': 'localhost', 'port': '25565'}
|
||||
self.assertDictEqual(server_conn._dsn_parameters, expected_dict)
|
||||
self.assertDictEqual(server_conn.dsn_parameters, expected_dict)
|
||||
self.assertEqual('pg', server_conn.server_type)
|
||||
self.assertTupleEqual((10, 2, 16), server_conn.version)
|
||||
|
||||
def test_execute_dict_success(self):
|
||||
# Setup: Create a mock server connection that will return a result set
|
||||
mock_cursor = utils.MockCursor(utils.get_mock_results())
|
||||
mock_conn = utils.MockConnection(mock_cursor)
|
||||
# noinspection PyTypeChecker
|
||||
server_conn = pgsmo_utils.querying.ServerConnection(mock_conn)
|
||||
|
||||
# If: I execute a query as a dictionary
|
||||
results = server_conn.execute_dict('SELECT * FROM pg_class')
|
||||
|
||||
# Then:
|
||||
# ... Both the columns and the results should be returned
|
||||
self.assertIsInstance(results, tuple)
|
||||
self.assertEqual(len(results), 2)
|
||||
|
||||
# ... I should get a list of columns returned to me
|
||||
cols = results[0]
|
||||
self.assertIsInstance(cols, list)
|
||||
self.assertListEqual(cols, mock_cursor.description)
|
||||
|
||||
# ... I should get the results formatted as a list of dictionaries
|
||||
rows = results[1]
|
||||
self.assertIsInstance(rows, list)
|
||||
for idx, row in enumerate(rows):
|
||||
self.assertDictEqual(row, mock_cursor._results[1][idx])
|
||||
|
||||
# ... The cursor should be closed
|
||||
mock_cursor.close.assert_called_once()
|
||||
|
||||
def test_execute_dict_fail(self):
|
||||
# Setup: Create a mock server connection that will raise an exception
|
||||
mock_cursor = utils.MockCursor(None, throw_on_execute=True)
|
||||
mock_conn = utils.MockConnection(mock_cursor)
|
||||
# noinspection PyTypeChecker
|
||||
server_conn = pgsmo_utils.querying.ServerConnection(mock_conn)
|
||||
|
||||
# If: I execute a query as a dictionary
|
||||
# Then:
|
||||
# ... I should get an exception
|
||||
with self.assertRaises(Exception):
|
||||
server_conn.execute_dict('SELECT * FROM pg_class')
|
||||
|
||||
# ... The cursor should be closed
|
||||
mock_cursor.close.assert_called_once()
|
|
@ -0,0 +1,178 @@
|
|||
# --------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for license information.
|
||||
# --------------------------------------------------------------------------------------------
|
||||
|
||||
import os.path as path
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
|
||||
import jinja2
|
||||
|
||||
import pgsmo.utils as pgsmo_utils
|
||||
|
||||
|
||||
class TestTemplatingUtils(unittest.TestCase):
|
||||
# GET_TEMPLATE_ROOT TESTS ##############################################
|
||||
def test_get_template_root(self):
|
||||
# If: I attempt to get the template root of this file
|
||||
root = pgsmo_utils.templating.get_template_root(__file__, 'templates')
|
||||
|
||||
# Then: The output should match what I expected
|
||||
expected = path.join(path.dirname(__file__), 'templates')
|
||||
self.assertEqual(root, expected)
|
||||
|
||||
# GET_TEMPLATE_PATH TESTS ##############################################
|
||||
def test_get_template_path_no_match(self):
|
||||
# Setup: Create a mock os walker
|
||||
with mock.patch('pgsmo.utils.templating.os.walk', _os_walker, create=True):
|
||||
with self.assertRaises(ValueError):
|
||||
# If: I attempt to get the path of a template when it does not exist
|
||||
# Then: An exception should be thrown
|
||||
pgsmo_utils.templating.get_template_path(TEMPLATE_ROOT_NAME, 'doesnotexist.sql', (9, 0, 0))
|
||||
|
||||
def test_get_template_path_default(self):
|
||||
# Setup: Create a mock os walker
|
||||
with mock.patch('pgsmo.utils.templating.os.walk', _os_walker, create=True):
|
||||
# If: I attempt to get a template path when there is not a good version match
|
||||
template_path = pgsmo_utils.templating.get_template_path(TEMPLATE_ROOT_NAME, 'template.sql', (8, 1, 0))
|
||||
|
||||
# Then: The template path should point to the default folder
|
||||
self.assertEqual(template_path, path.join(TEMPLATE_ROOT_NAME, '+default', 'template.sql'))
|
||||
|
||||
def test_get_template_path_exact_version_match(self):
|
||||
# Setup: Create a mock os walker
|
||||
with mock.patch('pgsmo.utils.templating.os.walk', _os_walker, create=True):
|
||||
# If: I attempt to get the path of a template when there's a exact version match
|
||||
template_path = pgsmo_utils.templating.get_template_path(TEMPLATE_ROOT_NAME, 'template.sql', (9, 0, 0))
|
||||
|
||||
# Then: The returned path must match the exact version
|
||||
self.assertEqual(template_path, path.join(TEMPLATE_ROOT_NAME, '9.0', 'template.sql'))
|
||||
|
||||
def test_get_template_path_plus_match(self):
|
||||
# Setup: Create a mock os walker
|
||||
with mock.patch('pgsmo.utils.templating.os.walk', _os_walker, create=True):
|
||||
# If: I attempt to get a template path when there is a version range that matches
|
||||
template_path = pgsmo_utils.templating.get_template_path(TEMPLATE_ROOT_NAME, 'template.sql', (9, 3, 0))
|
||||
|
||||
# Then: The returned path should match the next lowest version range
|
||||
self.assertEqual(template_path, path.join(TEMPLATE_ROOT_NAME, '9.2_plus', 'template.sql'))
|
||||
|
||||
def test_get_template_path_invalid_folder(self):
|
||||
# Setup: Create a mock os walker that has an invalid folder in it
|
||||
with mock.patch('pgsmo.utils.templating.os.walk', _bad_os_walker, create=True):
|
||||
with self.assertRaises(ValueError):
|
||||
# If: I attempt to get a template path when there is an invalid folder in the template folder
|
||||
# Then: I should get an exception
|
||||
pgsmo_utils.templating.get_template_path(TEMPLATE_ROOT_NAME, 'template.sql', (9, 0, 0))
|
||||
|
||||
# RENDER_TEMPLATE TESTS ################################################
|
||||
def test_render_template(self):
|
||||
# NOTE: This test has an external dependency on dummy_template.txt
|
||||
# If: I render a string
|
||||
template_file = 'dummy_template.txt'
|
||||
template_folder = path.dirname(__file__)
|
||||
template_path = path.normpath(path.join(template_folder, template_file))
|
||||
rendered = pgsmo_utils.templating.render_template(template_path, foo='bar')
|
||||
|
||||
# Then:
|
||||
# ... The output should be properly rendered
|
||||
self.assertEqual(rendered, 'bar')
|
||||
|
||||
# ... The environment should be cached
|
||||
self.assertIsInstance(pgsmo_utils.templating.TEMPLATE_ENVIRONMENTS, dict)
|
||||
env = pgsmo_utils.templating.TEMPLATE_ENVIRONMENTS.get(template_folder)
|
||||
self.assertIsInstance(env, jinja2.Environment)
|
||||
|
||||
# ... The environment should have the proper filters defined
|
||||
self.assertEquals(env.filters['qtLiteral'], pgsmo_utils.templating.qt_literal)
|
||||
self.assertEquals(env.filters['qtIdent'], pgsmo_utils.templating.qt_ident)
|
||||
self.assertEquals(env.filters['qtTypeIdent'], pgsmo_utils.templating.qt_type_ident)
|
||||
|
||||
def test_render_template_cached(self):
|
||||
# NOTE: This test has an external dependency on dummy_template.txt
|
||||
# If:
|
||||
# ... I render a string
|
||||
template_file = 'dummy_template.txt'
|
||||
template_folder = path.dirname(__file__)
|
||||
template_path = path.normpath(path.join(template_folder, template_file))
|
||||
rendered1 = pgsmo_utils.templating.render_template(template_path, foo='bar')
|
||||
env1 = pgsmo_utils.templating.TEMPLATE_ENVIRONMENTS.get(template_folder)
|
||||
|
||||
# ... I render the same string
|
||||
rendered2 = pgsmo_utils.templating.render_template(template_path, foo='bar')
|
||||
env2 = pgsmo_utils.templating.TEMPLATE_ENVIRONMENTS.get(template_folder)
|
||||
|
||||
# Then: The environments used should be literally the same
|
||||
self.assertEqual(rendered1, rendered2)
|
||||
self.assertIs(env1, env2)
|
||||
|
||||
# RENDER_TEMPLATE_STRING TESTS #########################################
|
||||
def test_render_template_string(self):
|
||||
# NOTE: doing very minimal test here since this function just uses jinja2 functionality
|
||||
# If: I render a string
|
||||
rendered = pgsmo_utils.templating.render_template_string('{{foo}}', foo='bar')
|
||||
|
||||
# Then: The string should be properly rendered
|
||||
self.assertEqual(rendered, 'bar')
|
||||
|
||||
|
||||
# Mock setup for a template tree
|
||||
TEMPLATE_ROOT_NAME = 'templates'
|
||||
TEMPLATE_EXACT_1 = (path.join(TEMPLATE_ROOT_NAME, '9.0'), [], ['template.sql'])
|
||||
TEMPLATE_EXACT_2 = (path.join(TEMPLATE_ROOT_NAME, '9.1'), [], ['template.sql'])
|
||||
TEMPLATE_PLUS_1 = (path.join(TEMPLATE_ROOT_NAME, '9.2_plus'), [], ['template.sql'])
|
||||
TEMPLATE_PLUS_2 = (path.join(TEMPLATE_ROOT_NAME, '9.4_plus'), [], ['template.sql'])
|
||||
TEMPLATE_DEFAULT = (path.join(TEMPLATE_ROOT_NAME, '+default'), [], ['template.sql'])
|
||||
|
||||
# Tests that skipped folders are not returned
|
||||
TEMPLATE_SKIP = (path.join(TEMPLATE_ROOT_NAME, 'macros'), [], ['template.sql'])
|
||||
|
||||
# Root tree
|
||||
TEMPLATE_ROOT = (
|
||||
TEMPLATE_ROOT_NAME,
|
||||
[
|
||||
TEMPLATE_DEFAULT[0],
|
||||
TEMPLATE_EXACT_1[0],
|
||||
TEMPLATE_EXACT_2[0],
|
||||
TEMPLATE_PLUS_1[0],
|
||||
TEMPLATE_PLUS_2[0],
|
||||
TEMPLATE_SKIP[0]
|
||||
],
|
||||
[]
|
||||
)
|
||||
|
||||
|
||||
def _os_walker(x):
|
||||
if x == TEMPLATE_ROOT_NAME:
|
||||
yield TEMPLATE_ROOT
|
||||
yield TEMPLATE_DEFAULT
|
||||
yield TEMPLATE_EXACT_1
|
||||
yield TEMPLATE_EXACT_2
|
||||
yield TEMPLATE_PLUS_1
|
||||
yield TEMPLATE_PLUS_2
|
||||
yield TEMPLATE_SKIP
|
||||
if x == TEMPLATE_DEFAULT[0]:
|
||||
yield TEMPLATE_DEFAULT
|
||||
if x == TEMPLATE_EXACT_1[0]:
|
||||
yield TEMPLATE_EXACT_1
|
||||
if x == TEMPLATE_EXACT_2[0]:
|
||||
yield TEMPLATE_EXACT_2
|
||||
if x == TEMPLATE_PLUS_1[0]:
|
||||
yield TEMPLATE_PLUS_1
|
||||
if x == TEMPLATE_PLUS_2[0]:
|
||||
yield TEMPLATE_PLUS_2
|
||||
if x == TEMPLATE_SKIP[0]:
|
||||
yield TEMPLATE_SKIP
|
||||
|
||||
|
||||
TEMPLATE_BAD = (path.join(TEMPLATE_ROOT_NAME, 'bad_folder'), [], ['template.sql'])
|
||||
TEMPLATE_BAD_ROOT = (TEMPLATE_ROOT_NAME, [TEMPLATE_BAD[0]], [])
|
||||
|
||||
|
||||
def _bad_os_walker(x):
|
||||
if x == TEMPLATE_ROOT_NAME:
|
||||
yield TEMPLATE_BAD_ROOT
|
||||
yield TEMPLATE_BAD
|
||||
if x == TEMPLATE_BAD[0]:
|
||||
yield TEMPLATE_BAD
|
|
@ -0,0 +1,84 @@
|
|||
# --------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for license information.
|
||||
# --------------------------------------------------------------------------------------------
|
||||
|
||||
import unittest
|
||||
|
||||
import pgsmo.utils as pgsmo_utils
|
||||
|
||||
|
||||
class TestTemplatingFilters(unittest.TestCase):
|
||||
# SCAN KEYWORD EXTRA LOOKUP TESTS ######################################
|
||||
def test_scan_keyword_extra_lookup_extra(self):
|
||||
# If: I scan for a keyword that exists in the extra keywords
|
||||
output = pgsmo_utils.templating.scan_keyword_extra_lookup('connect')
|
||||
|
||||
# Then: I should get back the extra keyword info
|
||||
self.assertEqual(output, pgsmo_utils.templating._EXTRA_KEYWORDS['connect'])
|
||||
|
||||
def test_scan_keyword_extra_lookup_standard(self):
|
||||
# If: I scan for a keyword that exists in the standard keywords
|
||||
output = pgsmo_utils.templating.scan_keyword_extra_lookup('abort')
|
||||
|
||||
# Then: I should get back the standard keyword value
|
||||
self.assertEqual(output, pgsmo_utils.templating._KEYWORD_DICT['abort'])
|
||||
|
||||
def test_scan_keyword_extra_lookup_none(self):
|
||||
# If: I scan for a keyword that doesn't exist
|
||||
output = pgsmo_utils.templating.scan_keyword_extra_lookup('does_not_exist')
|
||||
|
||||
# Then: I should get back None
|
||||
self.assertIsNone(output)
|
||||
|
||||
# QTLITERAL TESTS ######################################################
|
||||
def test_qtliteral_no_encoding(self):
|
||||
# If: I provide a value that doesn't have an encoding
|
||||
output = pgsmo_utils.templating.qt_literal(4)
|
||||
|
||||
# Then: I should get a quoted literal back
|
||||
self.assertEqual(output, '4')
|
||||
|
||||
def test_qtliteral_bytes(self):
|
||||
# If: I provide a byte array
|
||||
output = pgsmo_utils.templating.qt_literal(b'123')
|
||||
|
||||
# Then: I should get the byte array decoded as UTF-8 back
|
||||
self.assertEqual(output, "'123'::bytea")
|
||||
|
||||
def test_qtliteral_string(self):
|
||||
# If: I provide a string
|
||||
output = pgsmo_utils.templating.qt_literal("123")
|
||||
|
||||
# Then: I should get the quoted string back
|
||||
self.assertEqual(output, "'123'")
|
||||
|
||||
# NEEDS QUOTING TESTS ##################################################
|
||||
# TODO: Add tests that are more based on scenarios and less on code paths
|
||||
def test_needs_quoting_int(self):
|
||||
# If: An int is provided
|
||||
# Then: It should always be quoted
|
||||
self.assertTrue(pgsmo_utils.templating.needs_quoting(4, False))
|
||||
self.assertTrue(pgsmo_utils.templating.needs_quoting(4, True))
|
||||
|
||||
def test_needs_quoting_type_legal_spaces(self):
|
||||
# If: A type is provided that legally has a space in it
|
||||
# Then: It shouldn't be quoted
|
||||
self.assertFalse(pgsmo_utils.templating.needs_quoting('time with time zone', True))
|
||||
self.assertFalse(pgsmo_utils.templating.needs_quoting('time with time zone[]', True))
|
||||
|
||||
def test_needs_quoting_type_already_quoted(self):
|
||||
# If: A type is provided that is already quoted
|
||||
# Then: It shouldn't be quoted
|
||||
self.assertFalse(pgsmo_utils.templating.needs_quoting('"int"', True))
|
||||
|
||||
def test_needs_quoting_numeric_value(self):
|
||||
# If: A value is numeric (? starts with 0-9)
|
||||
# Then: It should be quoted
|
||||
self.assertTrue(pgsmo_utils.templating.needs_quoting('2000', False))
|
||||
|
||||
def test_needs_quoting_non_alphanumeric(self):
|
||||
# If: A value is not lowercase alphanumeric
|
||||
# Then: It should be quoted
|
||||
self.assertTrue(pgsmo_utils.templating.needs_quoting('Something+Else', False))
|
||||
self.assertTrue(pgsmo_utils.templating.needs_quoting('Something+Else', True))
|
|
@ -0,0 +1,179 @@
|
|||
# --------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for license information.
|
||||
# --------------------------------------------------------------------------------------------
|
||||
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
|
||||
from psycopg2 import DatabaseError
|
||||
from psycopg2.extensions import Column, connection
|
||||
|
||||
from pgsmo.objects.node_object import NodeObject, NodeCollection
|
||||
from pgsmo.utils.querying import ServerConnection
|
||||
|
||||
|
||||
# MOCK CONNECTION ##########################################################
|
||||
def get_mock_columns(col_count: int) -> List[Column]:
|
||||
return [Column(f'column{i}', None, 10, 10, None, None, True) for i in range(0, col_count+1)]
|
||||
|
||||
|
||||
def get_named_mock_columns(col_names: List[str]) -> List[Column]:
|
||||
return [Column(x, None, 10, 10, None, None, True) for x in col_names]
|
||||
|
||||
|
||||
def get_mock_results(col_count: int=5, row_count: int=5) -> Tuple[List[Column], List[dict]]:
|
||||
rows = []
|
||||
cols = get_mock_columns(col_count)
|
||||
for i in range(0, len(cols)):
|
||||
# Add the column to the rows
|
||||
for j in range(0, row_count+1):
|
||||
if len(rows) >= j:
|
||||
rows.append({})
|
||||
|
||||
rows[j][cols[i].name] = f'value{j}.{i}'
|
||||
|
||||
return cols, rows
|
||||
|
||||
|
||||
class MockCursor:
|
||||
def __init__(self, results: Optional[Tuple[List[Column], List[dict]]], throw_on_execute=False):
|
||||
# Setup the results, that will change value once the cursor is executed
|
||||
self.description = None
|
||||
self.rowcount = None
|
||||
self._results = results
|
||||
self._throw_on_execute = throw_on_execute
|
||||
|
||||
# Define iterator state
|
||||
self._has_been_read = False
|
||||
self._iter_index = 0
|
||||
|
||||
# Define mocks for the methods of the cursor
|
||||
self.execute = mock.MagicMock(side_effect=self._execute)
|
||||
self.close = mock.MagicMock()
|
||||
|
||||
def __iter__(self):
|
||||
# If we haven't read yet, raise an error
|
||||
# Or if we have read but we're past the end of the list, raise an error
|
||||
if not self._has_been_read or self._iter_index > len(self._results[1]):
|
||||
raise StopIteration
|
||||
|
||||
# From python 3.6+ this dicts preserve order, so this isn't an issue
|
||||
yield list(self._results[1][self._iter_index].values())
|
||||
self._iter_index += 1
|
||||
|
||||
def _execute(self, query, params):
|
||||
# Raise error if that was expected, otherwise set the output
|
||||
if self._throw_on_execute:
|
||||
raise DatabaseError()
|
||||
|
||||
self.description = self._results[0]
|
||||
self.rowcount = len(self._results[1])
|
||||
self._has_been_read = True
|
||||
|
||||
|
||||
class MockConnection(connection):
|
||||
def __init__(
|
||||
self,
|
||||
cur: Optional[MockCursor],
|
||||
version: str='90602',
|
||||
name: str='postgres',
|
||||
host: str='localhost',
|
||||
port: str='25565'):
|
||||
# Setup the properties
|
||||
self._server_version = version
|
||||
|
||||
# Setup mocks for the connection
|
||||
self.close = mock.MagicMock()
|
||||
self.cursor = mock.MagicMock(return_value=cur)
|
||||
|
||||
dsn_params = {'dbname': name, 'host': host, 'port': port}
|
||||
self.get_dsn_parameters = mock.MagicMock(return_value=dsn_params)
|
||||
|
||||
@property
|
||||
def server_version(self):
|
||||
return self._server_version
|
||||
|
||||
|
||||
# OBJECT TEST HELPERS ######################################################
|
||||
def get_nodes_for_parent_base(class_, data: dict, get_nodes_for_parent: Callable, validate_obj: Callable):
|
||||
# Setup: Create a mockup server connection
|
||||
mock_cur = MockCursor((get_named_mock_columns(list(data.keys())), [data for i in range(0, 6)]))
|
||||
mock_conn = ServerConnection(MockConnection(mock_cur))
|
||||
|
||||
# ... Create a mock template renderer
|
||||
mock_render = mock.MagicMock(return_value="SQL")
|
||||
mock_template_path = mock.MagicMock(return_value="path")
|
||||
|
||||
# ... Create a testcase for calling asserts with
|
||||
test_case = unittest.TestCase('__init__')
|
||||
|
||||
# ... Patch the templating
|
||||
with mock.patch(class_.__module__ + '.templating.render_template', mock_render, create=True):
|
||||
with mock.patch(class_.__module__ + '.templating.get_template_path', mock_template_path, create=True):
|
||||
# If: ask for a collection of nodes
|
||||
output = get_nodes_for_parent(mock_conn)
|
||||
|
||||
# Then:
|
||||
# ... The output should be a list of objects
|
||||
test_case.assertIsInstance(output, list)
|
||||
|
||||
for obj in output:
|
||||
# ... The object must be the class that was passed in
|
||||
test_case.assertIsInstance(obj, class_)
|
||||
|
||||
# ... Call the validator on the object
|
||||
validate_obj(obj, mock_conn)
|
||||
|
||||
|
||||
def from_node_query_base(class_, data: dict, validate_obj: Callable):
|
||||
# If: I create a new object from a node row
|
||||
mock_conn = ServerConnection(MockConnection(None))
|
||||
obj = class_._from_node_query(mock_conn, **data)
|
||||
|
||||
# Then:
|
||||
# ... The returned object must be an instance of the class
|
||||
test_case = unittest.TestCase('__init__')
|
||||
test_case.assertIsInstance(obj, NodeObject)
|
||||
test_case.assertIsInstance(obj, class_)
|
||||
|
||||
# ... Call the validation function
|
||||
validate_obj(obj, mock_conn)
|
||||
|
||||
|
||||
def init_base(class_, props: List[str], collections: List[str], custom_validation: Callable=None):
|
||||
# If: I create an instance of the provided class
|
||||
mock_conn = ServerConnection(MockConnection(None))
|
||||
name = 'test'
|
||||
obj = class_(mock_conn, name)
|
||||
|
||||
validate_init(class_, name, mock_conn, obj, props, collections, custom_validation)
|
||||
|
||||
|
||||
def validate_init(class_, name, mock_conn, obj, props: List[str], collections: List[str],
|
||||
custom_validation: Callable=None):
|
||||
# Then:
|
||||
# ... The object must be of the type that was provided
|
||||
test_case = unittest.TestCase('__init__')
|
||||
test_case.assertIsInstance(obj, NodeObject)
|
||||
test_case.assertIsInstance(obj, class_)
|
||||
|
||||
# ... The NodeObject basic properties should be set up appropriately
|
||||
test_case.assertIs(obj._conn, mock_conn)
|
||||
test_case.assertEqual(obj._name, name)
|
||||
test_case.assertEqual(obj.name, name)
|
||||
test_case.assertIsNone(obj._oid)
|
||||
test_case.assertIsNone(obj.oid)
|
||||
|
||||
# ... The rest of the properties should be none
|
||||
for prop in props:
|
||||
test_case.assertIsNone(getattr(obj, prop))
|
||||
|
||||
# ... The child properties should be assigned to node collections
|
||||
for coll in collections:
|
||||
test_case.assertIsInstance(getattr(obj, coll), NodeCollection)
|
||||
|
||||
# ... Run the custom validation
|
||||
if custom_validation is not None:
|
||||
custom_validation(obj)
|
Загрузка…
Ссылка в новой задаче